From 45eacb782d438686e4580fe94cab5d2fb953b200 Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Fri, 31 Mar 2023 13:17:45 +0200 Subject: [PATCH] patch qkv_rot --- Cargo.lock | 794 ++++++++++-------- router/src/main.rs | 2 +- .../custom_modeling/flash_llama_modeling.py | 204 +---- .../custom_modeling/flash_neox_modeling.py | 198 +---- .../models/custom_modeling/linear.py | 22 + .../models/custom_modeling/rotary.py | 42 + .../models/custom_modeling/tensor_parallel.py | 124 +++ .../models/flash_llama.py | 2 + .../models/flash_neox.py | 6 +- 9 files changed, 683 insertions(+), 711 deletions(-) create mode 100644 server/text_generation_server/models/custom_modeling/linear.py create mode 100644 server/text_generation_server/models/custom_modeling/rotary.py create mode 100644 server/text_generation_server/models/custom_modeling/tensor_parallel.py diff --git a/Cargo.lock b/Cargo.lock index 8e469f23..ca60f997 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -8,6 +8,18 @@ version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" +[[package]] +name = "aes" +version = "0.7.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9e8b47f52ea9bae42228d07ec09eb676433d7c4ed1ebdf0f1d1c29ed446f1ab8" +dependencies = [ + "cfg-if", + "cipher", + "cpufeatures", + "opaque-debug", +] + [[package]] name = "ahash" version = "0.7.6" @@ -29,19 +41,50 @@ dependencies = [ ] [[package]] -name = "ansi_term" -version = "0.12.1" +name = "anstream" +version = "0.2.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d52a9bb7ec0cf484c551830a7ce27bd20d67eac647e1befb56b0be4ee39a55d2" +checksum = "342258dd14006105c2b75ab1bd7543a03bdf0cfc94383303ac212a04939dff6f" dependencies = [ - "winapi", + "anstyle", + "anstyle-parse", + "anstyle-wincon", + "concolor-override", + "concolor-query", + "is-terminal", + "utf8parse", +] + +[[package]] +name = "anstyle" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "23ea9e81bd02e310c216d080f6223c179012256e5151c41db88d12c88a1684d2" + +[[package]] +name = "anstyle-parse" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a7d1bb534e9efed14f3e5f44e7dd1a4f709384023a4165199a4241e18dff0116" +dependencies = [ + "utf8parse", +] + +[[package]] +name = "anstyle-wincon" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c3127af6145b149f3287bb9a0d10ad9c5692dba8c53ad48285e5bec4063834fa" +dependencies = [ + "anstyle", + "windows-sys 0.45.0", ] [[package]] name = "anyhow" -version = "1.0.69" +version = "1.0.70" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "224afbd727c3d6e4b90103ece64b8d1b67fbb1973b1046c2281eed3f3803f800" +checksum = "7de8ce5e0f9f8d88245311066a578d72b7af3e7088f32783804676302df237e4" [[package]] name = "async-stream" @@ -62,29 +105,18 @@ checksum = "e4655ae1a7b0cdf149156f780c5bf3f1352bc53cbd9e0a361a7ef7b22947e965" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 1.0.109", ] [[package]] name = "async-trait" -version = "0.1.64" +version = "0.1.68" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1cd7fce9ba8c3c042128ce72d8b2ddbf3a05747efb67ea0313c635e10bda47a2" +checksum = "b9ccdd8f2a161be9bd5c023df56f1b2a0bd1d83872ae53b71a84a12c9bf6e842" dependencies = [ "proc-macro2", "quote", - "syn", -] - -[[package]] -name = "atty" -version = "0.2.14" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d9b39be18770d11421cdb1b9947a45dd3f37e93092cbf377614828a319d5fee8" -dependencies = [ - "hermit-abi 0.1.19", - "libc", - "winapi", + "syn 2.0.12", ] [[package]] @@ -95,9 +127,9 @@ checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" [[package]] name = "axum" -version = "0.6.9" +version = "0.6.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6137c6234afb339e75e764c866e3594900f0211e1315d33779f269bbe2ec6967" +checksum = "349f8ccfd9221ee7d1f3d4b33e1f8319b3a81ed8f61f2ea40b37b859794b4491" dependencies = [ "async-trait", "axum-core", @@ -121,16 +153,15 @@ dependencies = [ "sync_wrapper", "tokio", "tower", - "tower-http 0.4.0", "tower-layer", "tower-service", ] [[package]] name = "axum-core" -version = "0.3.2" +version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1cae3e661676ffbacb30f1a824089a8c9150e71017f7e1e38f2aa32009188d34" +checksum = "b2f958c80c248b34b9a877a643811be8dbca03ca5ba827f2b63baf3a81e5fc4e" dependencies = [ "async-trait", "bytes", @@ -154,7 +185,7 @@ dependencies = [ "http", "opentelemetry", "tower", - "tower-http 0.3.5", + "tower-http", "tracing", "tracing-opentelemetry", ] @@ -171,6 +202,12 @@ version = "0.21.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a4a4ddaa51a5bc52a6948f74c06d20aaaddb71924eab79b8c97a8c556e942d6a" +[[package]] +name = "base64ct" +version = "1.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8c3c1a368f70d6cf7302d78f8f7093da241fb8e8807c05cc9e51a125895a6d5b" + [[package]] name = "bitflags" version = "1.3.2" @@ -179,9 +216,9 @@ checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" [[package]] name = "block-buffer" -version = "0.10.3" +version = "0.10.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "69cce20737498f97b993470a6e536b8523f0af7892a4f928cceb1ac5e52ebe7e" +checksum = "3078c7629b62d3f0439517fa394996acacc5cbc91c5a20d8c658e77abd503a71" dependencies = [ "generic-array", ] @@ -227,9 +264,9 @@ dependencies = [ [[package]] name = "cached-path" -version = "0.5.3" +version = "0.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5f1c56d30236522ab3393a08746b138d4e16372001f42d29c88d513aeb8ab7ef" +checksum = "097968e38f1319207f057d0f4d76452e4f4f847a5de61c5215379f297fa034f3" dependencies = [ "flate2", "fs2", @@ -244,8 +281,7 @@ dependencies = [ "tar", "tempfile", "thiserror", - "zip 0.5.13", - "zip-extensions", + "zip", ] [[package]] @@ -253,6 +289,9 @@ name = "cc" version = "1.0.79" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "50d30906286121d95be3d479533b458f87493b30a4b5f79a607db8f5d11aa91f" +dependencies = [ + "jobserver", +] [[package]] name = "cfg-if" @@ -261,55 +300,69 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" [[package]] -name = "clap" -version = "2.34.0" +name = "cipher" +version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a0610544180c38b88101fecf2dd634b174a62eef6946f84dfc6a7127512b381c" +checksum = "7ee52072ec15386f770805afd189a01c8841be8696bed250fa2f13c4c0d6dfb7" dependencies = [ - "ansi_term", - "atty", - "bitflags", - "strsim 0.8.0", - "textwrap", - "unicode-width", - "vec_map", + "generic-array", ] [[package]] name = "clap" -version = "4.1.8" +version = "4.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c3d7ae14b20b94cb02149ed21a86c423859cbe18dc7ed69845cace50e52b40a5" +checksum = "046ae530c528f252094e4a77886ee1374437744b2bff1497aa898bbddbbb29b3" dependencies = [ - "bitflags", + "clap_builder", "clap_derive", - "clap_lex", - "is-terminal", "once_cell", - "strsim 0.10.0", - "termcolor", +] + +[[package]] +name = "clap_builder" +version = "4.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "223163f58c9a40c3b0a43e1c4b50a9ce09f007ea2cb1ec258a687945b4b7929f" +dependencies = [ + "anstream", + "anstyle", + "bitflags", + "clap_lex", + "strsim", ] [[package]] name = "clap_derive" -version = "4.1.8" +version = "4.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "44bec8e5c9d09e439c4335b1af0abaab56dcf3b94999a936e1bb47b9134288f0" +checksum = "3f9644cd56d6b87dbe899ef8b053e331c0637664e9e21a33dfcdc36093f5c5c4" dependencies = [ "heck", - "proc-macro-error", "proc-macro2", "quote", - "syn", + "syn 2.0.12", ] [[package]] name = "clap_lex" -version = "0.3.2" +version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "350b9cf31731f9957399229e9b2adc51eeabdfbe9d71d9a0552275fd12710d09" +checksum = "8a2dd5a6fe8c6e3502f568a6353e5273bbb15193ad9a89e457b9970798efbea1" + +[[package]] +name = "concolor-override" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a855d4a1978dc52fb0536a04d384c2c0c1aa273597f08b77c8c4d3b2eec6037f" + +[[package]] +name = "concolor-query" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "88d11d52c3d7ca2e6d0040212be9e4dbbcd78b6447f535b6b561f449427944cf" dependencies = [ - "os_str_bytes", + "windows-sys 0.45.0", ] [[package]] @@ -325,6 +378,12 @@ dependencies = [ "windows-sys 0.42.0", ] +[[package]] +name = "constant_time_eq" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "245097e9a4535ee1e3e3931fcfcd55a796a44c643e8596ff6566d68f09b87bbc" + [[package]] name = "core-foundation" version = "0.9.3" @@ -343,9 +402,9 @@ checksum = "5827cebf4670468b8772dd191856768aedcb1b0278a04f989f7766351917b9dc" [[package]] name = "cpufeatures" -version = "0.2.5" +version = "0.2.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "28d997bd5e24a5928dd43e46dc529867e207907fe0b239c3477d924f7f2ca320" +checksum = "280a9f2d8b3a38871a3c8a46fb80db65e5e5ed97da80c4d08bf27fb63e35e181" dependencies = [ "libc", ] @@ -424,9 +483,9 @@ dependencies = [ [[package]] name = "darling" -version = "0.10.2" +version = "0.14.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0d706e75d87e35569db781a9b5e2416cff1236a47ed380831f959382ccd5f858" +checksum = "7b750cb3417fd1b327431a470f388520309479ab0bf5e323505daf0290cd3850" dependencies = [ "darling_core", "darling_macro", @@ -434,27 +493,27 @@ dependencies = [ [[package]] name = "darling_core" -version = "0.10.2" +version = "0.14.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f0c960ae2da4de88a91b2d920c2a7233b400bc33cb28453a2987822d8392519b" +checksum = "109c1ca6e6b7f82cc233a97004ea8ed7ca123a9af07a8230878fcfda9b158bf0" dependencies = [ "fnv", "ident_case", "proc-macro2", "quote", - "strsim 0.9.3", - "syn", + "strsim", + "syn 1.0.109", ] [[package]] name = "darling_macro" -version = "0.10.2" +version = "0.14.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d9b5a2f4ac4969822c62224815d069952656cadc7084fdca9751e6d959189b72" +checksum = "a4aab4dbc9f7611d8b55048a3a16d2d010c2c8334e46304b40ac1cc14bf3b48e" dependencies = [ "darling_core", "quote", - "syn", + "syn 1.0.109", ] [[package]] @@ -472,27 +531,33 @@ dependencies = [ [[package]] name = "derive_builder" -version = "0.9.0" +version = "0.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a2658621297f2cf68762a6f7dc0bb7e1ff2cfd6583daef8ee0fed6f7ec468ec0" +checksum = "8d67778784b508018359cbc8696edb3db78160bab2c2a28ba7f56ef6932997f8" dependencies = [ - "darling", - "derive_builder_core", - "proc-macro2", - "quote", - "syn", + "derive_builder_macro", ] [[package]] name = "derive_builder_core" -version = "0.9.0" +version = "0.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2791ea3e372c8495c0bc2033991d76b512cd799d07491fbd6890124db9458bef" +checksum = "c11bdc11a0c47bc7d37d582b5285da6849c96681023680b906673c5707af7b0f" dependencies = [ "darling", "proc-macro2", "quote", - "syn", + "syn 1.0.109", +] + +[[package]] +name = "derive_builder_macro" +version = "0.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ebcda35c7a396850a55ffeac740804b40ffec779b98fffbb1738f4033f0ee79e" +dependencies = [ + "derive_builder_core", + "syn 1.0.109", ] [[package]] @@ -503,15 +568,7 @@ checksum = "8168378f4e5023e7218c89c891c0fd8ecdb5e5e4f18cb78f38cf245dd021e76f" dependencies = [ "block-buffer", "crypto-common", -] - -[[package]] -name = "dirs" -version = "3.0.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "30baa043103c9d0c2a57cf537cc2f35623889dc0d405e6c3cccfadbc81c71309" -dependencies = [ - "dirs-sys", + "subtle", ] [[package]] @@ -557,13 +614,13 @@ dependencies = [ [[package]] name = "errno" -version = "0.2.8" +version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f639046355ee4f37944e44f60642c6f3a7efa3cf6b78c78a0d989a8ce6c396a1" +checksum = "50d6a0976c999d473fe89ad888d5a284e55366d9dc9038b1ba2aa15128c4afa0" dependencies = [ "errno-dragonfly", "libc", - "winapi", + "windows-sys 0.45.0", ] [[package]] @@ -602,7 +659,7 @@ checksum = "8a3de6e8d11b22ff9edc6d916f890800597d60f8b2da1caf2955c274638d6412" dependencies = [ "cfg-if", "libc", - "redox_syscall", + "redox_syscall 0.2.16", "windows-sys 0.45.0", ] @@ -670,9 +727,9 @@ dependencies = [ [[package]] name = "futures" -version = "0.3.26" +version = "0.3.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "13e2792b0ff0340399d58445b88fd9770e3489eff258a4cbc1523418f12abf84" +checksum = "23342abe12aba583913b2e62f22225ff9c950774065e4bfb61a19cd9770fec40" dependencies = [ "futures-channel", "futures-core", @@ -685,9 +742,9 @@ dependencies = [ [[package]] name = "futures-channel" -version = "0.3.26" +version = "0.3.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2e5317663a9089767a1ec00a487df42e0ca174b61b4483213ac24448e4664df5" +checksum = "955518d47e09b25bbebc7a18df10b81f0c766eaf4c4f1cccef2fca5f2a4fb5f2" dependencies = [ "futures-core", "futures-sink", @@ -695,15 +752,15 @@ dependencies = [ [[package]] name = "futures-core" -version = "0.3.26" +version = "0.3.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ec90ff4d0fe1f57d600049061dc6bb68ed03c7d2fbd697274c41805dcb3f8608" +checksum = "4bca583b7e26f571124fe5b7561d49cb2868d79116cfa0eefce955557c6fee8c" [[package]] name = "futures-executor" -version = "0.3.26" +version = "0.3.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e8de0a35a6ab97ec8869e32a2473f4b1324459e14c29275d14b10cb1fd19b50e" +checksum = "ccecee823288125bd88b4d7f565c9e58e41858e47ab72e8ea2d64e93624386e0" dependencies = [ "futures-core", "futures-task", @@ -712,38 +769,38 @@ dependencies = [ [[package]] name = "futures-io" -version = "0.3.26" +version = "0.3.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bfb8371b6fb2aeb2d280374607aeabfc99d95c72edfe51692e42d3d7f0d08531" +checksum = "4fff74096e71ed47f8e023204cfd0aa1289cd54ae5430a9523be060cdb849964" [[package]] name = "futures-macro" -version = "0.3.26" +version = "0.3.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "95a73af87da33b5acf53acfebdc339fe592ecf5357ac7c0a7734ab9d8c876a70" +checksum = "89ca545a94061b6365f2c7355b4b32bd20df3ff95f02da9329b34ccc3bd6ee72" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.12", ] [[package]] name = "futures-sink" -version = "0.3.26" +version = "0.3.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f310820bb3e8cfd46c80db4d7fb8353e15dfff853a127158425f31e0be6c8364" +checksum = "f43be4fe21a13b9781a69afa4985b0f6ee0e1afab2c6f454a8cf30e2b2237b6e" [[package]] name = "futures-task" -version = "0.3.26" +version = "0.3.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dcf79a1bf610b10f42aea489289c5a2c478a786509693b80cd39c44ccd936366" +checksum = "76d3d132be6c0e6aa1534069c705a74a5997a356c0dc2f86a47765e5617c5b65" [[package]] name = "futures-util" -version = "0.3.26" +version = "0.3.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9c1d6de3acfef38d2be4b1f543f553131788603495be83da675e180c8d6b7bd1" +checksum = "26b01e40b772d54cf6c6d721c1d1abd0647a0106a12ecaa1c186273392a69533" dependencies = [ "futures-channel", "futures-core", @@ -759,9 +816,9 @@ dependencies = [ [[package]] name = "generic-array" -version = "0.14.6" +version = "0.14.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bff49e947297f3312447abdca79f45f4738097cc82b06e72054d2223f601f1b9" +checksum = "85649ca51fd72272d7821adaf274ad91c288277713d9c18820d8499a7ff69e9a" dependencies = [ "typenum", "version_check", @@ -828,15 +885,6 @@ version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" -[[package]] -name = "hermit-abi" -version = "0.1.19" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "62b467343b94ba476dcb2500d242dadbb39557df889310ac77c5d99100aaac33" -dependencies = [ - "libc", -] - [[package]] name = "hermit-abi" version = "0.2.6" @@ -852,6 +900,15 @@ version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fed44880c466736ef9a5c5b5facefb5ed0785676d0c02d612db14e54f0d84286" +[[package]] +name = "hmac" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c49c37c09c17a53d937dfbb742eb3a961d65a994e6bcdcf37e7399d0cc8ab5e" +dependencies = [ + "digest", +] + [[package]] name = "http" version = "0.2.9" @@ -894,9 +951,9 @@ checksum = "c4a1e36c821dbe04574f602848a19f742f4fb3c98d40449f11bcad18d6b17421" [[package]] name = "hyper" -version = "0.14.24" +version = "0.14.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5e011372fa0b68db8350aa7a248930ecc7839bf46d8485577d69f117a75f164c" +checksum = "cc5e554ff619822309ffd57d8734d77cd5ce6238bc956f037ea06c58238c9899" dependencies = [ "bytes", "futures-channel", @@ -959,9 +1016,9 @@ dependencies = [ [[package]] name = "indexmap" -version = "1.9.2" +version = "1.9.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1885e79c1fc4b10f0e172c475f458b7f7b93061064d98c3293e98c5ba0c8b399" +checksum = "bd070e393353796e801d209ad339e89596eb4c8d430d18ede6a1cced8fafbd99" dependencies = [ "autocfg", "hashbrown", @@ -1003,25 +1060,26 @@ dependencies = [ [[package]] name = "io-lifetimes" -version = "1.0.5" +version = "1.0.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1abeb7a0dd0f8181267ff8adc397075586500b81b28a73e8a0208b00fc170fb3" +checksum = "09270fd4fa1111bc614ed2246c7ef56239a3063d5be0d1ec3b589c505d400aeb" dependencies = [ + "hermit-abi 0.3.1", "libc", "windows-sys 0.45.0", ] [[package]] name = "ipnet" -version = "2.7.1" +version = "2.7.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "30e22bd8629359895450b59ea7a776c850561b96a3b1d31321c1949d9e6c9146" +checksum = "12b6ee2129af8d4fb011108c73d99a1b83a85977f23b82460c0ae2e25bb4b57f" [[package]] name = "is-terminal" -version = "0.4.4" +version = "0.4.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "21b6b32576413a8e69b90e952e4a026476040d81017b80445deda5f2d3921857" +checksum = "256017f749ab3117e93acb91063009e1f1bb56d03965b14c2c8df4eb02c524d8" dependencies = [ "hermit-abi 0.3.1", "io-lifetimes", @@ -1058,9 +1116,18 @@ dependencies = [ [[package]] name = "itoa" -version = "1.0.5" +version = "1.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fad582f4b9e86b6caa621cabeb0963332d92eea04729ab12892c2533951e6440" +checksum = "453ad9f582a441959e5f0d088b02ce04cfe8d51a8eaf077f12ac6d3e94164ca6" + +[[package]] +name = "jobserver" +version = "0.1.26" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "936cfd212a0155903bcbc060e316fb6cc7cbf2e1907329391ebadc1fe0ce77c2" +dependencies = [ + "libc", +] [[package]] name = "js-sys" @@ -1079,15 +1146,15 @@ checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" [[package]] name = "libc" -version = "0.2.139" +version = "0.2.140" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "201de327520df007757c1f0adce6e827fe8562fbc28bfd9c15571c66ca1f5f79" +checksum = "99227334921fae1a979cf0bfdfcc6b3e5ce376ef57e16fb6fb3ea2ed6095f80c" [[package]] name = "linux-raw-sys" -version = "0.1.4" +version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f051f77a7c8e6957c0696eac88f26b0117e54f52d3fc682ab19397a8812846a4" +checksum = "d59d8c75012853d2e872fb56bc8a2e53718e2cafe1a4c823143141c6d90c322f" [[package]] name = "lock_api" @@ -1201,7 +1268,7 @@ checksum = "731f8ecebd9f3a4aa847dfe75455e4757a45da40a7793d2f0b1f9b6ed18b23f3" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 1.0.109", ] [[package]] @@ -1223,9 +1290,9 @@ dependencies = [ [[package]] name = "mime" -version = "0.3.16" +version = "0.3.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2a60c7ce501c71e03a9c9c0d35b861413ae925bd979cc7a4e30d060069aaac8d" +checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" [[package]] name = "mime_guess" @@ -1264,6 +1331,27 @@ dependencies = [ "windows-sys 0.45.0", ] +[[package]] +name = "monostate" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0230b703f1ac35df1e24f6d0d2255472bcccaf657ecdfa4f1fcbcad1ad5bb98a" +dependencies = [ + "monostate-impl", + "serde", +] + +[[package]] +name = "monostate-impl" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8795add3e14028f11f8e848bd3294898a8294767b3776b6f733560d33bd2530b" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.12", +] + [[package]] name = "multimap" version = "0.8.3" @@ -1377,10 +1465,16 @@ dependencies = [ ] [[package]] -name = "openssl" -version = "0.10.45" +name = "opaque-debug" +version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b102428fd03bc5edf97f62620f7298614c45cedf287c271e7ed450bbaf83f2e1" +checksum = "624a8340c38c1b80fd549087862da4ba43e08858af025b236e509b6649fc13d5" + +[[package]] +name = "openssl" +version = "0.10.48" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "518915b97df115dd36109bfa429a48b8f737bd05508cf9588977b599648926d2" dependencies = [ "bitflags", "cfg-if", @@ -1399,7 +1493,7 @@ checksum = "b501e44f11665960c7e7fcf062c7d96a14ade4aa98116c004b2e37b5be7d736c" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 1.0.109", ] [[package]] @@ -1410,9 +1504,9 @@ checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf" [[package]] name = "openssl-sys" -version = "0.9.80" +version = "0.9.83" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "23bbbf7854cd45b83958ebe919f0e8e516793727652e27fda10a8384cfc790b7" +checksum = "666416d899cf077260dac8698d60a60b435a46d57e82acb1be3d0dad87284e5b" dependencies = [ "autocfg", "cc", @@ -1501,12 +1595,6 @@ dependencies = [ "tokio-stream", ] -[[package]] -name = "os_str_bytes" -version = "6.4.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9b7820b9daea5457c9f21c69448905d723fbd21136ccf521748f23fd49e723ee" - [[package]] name = "overload" version = "0.1.1" @@ -1531,16 +1619,39 @@ checksum = "9069cbb9f99e3a5083476ccb29ceb1de18b9118cafa53e90c9551235de2b9521" dependencies = [ "cfg-if", "libc", - "redox_syscall", + "redox_syscall 0.2.16", "smallvec", "windows-sys 0.45.0", ] [[package]] -name = "paste" -version = "1.0.11" +name = "password-hash" +version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d01a5bd0424d00070b0098dd17ebca6f961a959dead1dbcbbbc1d1cd8d3deeba" +checksum = "7676374caaee8a325c9e7a2ae557f216c5563a171d6997b0ef8a65af35147700" +dependencies = [ + "base64ct", + "rand_core", + "subtle", +] + +[[package]] +name = "paste" +version = "1.0.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9f746c4065a8fa3fe23974dd82f15431cc8d40779821001404d10d2e79ca7d79" + +[[package]] +name = "pbkdf2" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "83a0692ec44e4cf1ef28ca317f14f8f07da2d95ec3fa01f86e4467b725e60917" +dependencies = [ + "digest", + "hmac", + "password-hash", + "sha2", +] [[package]] name = "percent-encoding" @@ -1575,7 +1686,7 @@ checksum = "069bdb1e05adc7a8990dce9cc75370895fbe4e3d58b9b73bf1aee56359344a55" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 1.0.109", ] [[package]] @@ -1610,12 +1721,12 @@ checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" [[package]] name = "prettyplease" -version = "0.1.23" +version = "0.1.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e97e3215779627f01ee256d2fad52f3d95e8e1c11e9fc6fd08f7cd455d5d5c78" +checksum = "6c8646e95016a7a6c4adea95bafa8a16baab64b583356217f2c85db4a39d9a86" dependencies = [ "proc-macro2", - "syn", + "syn 1.0.109", ] [[package]] @@ -1627,7 +1738,7 @@ dependencies = [ "proc-macro-error-attr", "proc-macro2", "quote", - "syn", + "syn 1.0.109", "version_check", ] @@ -1644,9 +1755,9 @@ dependencies = [ [[package]] name = "proc-macro2" -version = "1.0.51" +version = "1.0.54" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5d727cae5b39d21da60fa540906919ad737832fe0b1c165da3a34d6548c849d6" +checksum = "e472a104799c74b514a57226160104aa483546de37e839ec50e3c2e41dd87534" dependencies = [ "unicode-ident", ] @@ -1678,7 +1789,7 @@ dependencies = [ "prost", "prost-types", "regex", - "syn", + "syn 1.0.109", "tempfile", "which", ] @@ -1693,7 +1804,7 @@ dependencies = [ "itertools 0.10.5", "proc-macro2", "quote", - "syn", + "syn 1.0.109", ] [[package]] @@ -1723,9 +1834,9 @@ dependencies = [ [[package]] name = "quote" -version = "1.0.23" +version = "1.0.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8856d8364d252a14d474036ea1358d63c9e6965c8e5c1885c18f73d70bff9c7b" +checksum = "4424af4bf778aae2051a77b60283332f386554255d722233d09fbfc7e30da2fc" dependencies = [ "proc-macro2", ] @@ -1771,9 +1882,9 @@ dependencies = [ [[package]] name = "rayon" -version = "1.6.1" +version = "1.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6db3a213adf02b3bcfd2d3846bb41cb22857d131789e01df434fb7e7bc0759b7" +checksum = "1d2df5196e37bcc87abebc0053e20787d73847bb33134a69841207dd0a47f03b" dependencies = [ "either", "rayon-core", @@ -1792,9 +1903,9 @@ dependencies = [ [[package]] name = "rayon-core" -version = "1.10.2" +version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "356a0625f1954f730c0201cdab48611198dc6ce21f4acff55089b5a78e6e835b" +checksum = "4b8f95bd6966f5c87776639160a66bd8ab9895d9d4ab01ddba9fc60661aebe8d" dependencies = [ "crossbeam-channel", "crossbeam-deque", @@ -1811,6 +1922,15 @@ dependencies = [ "bitflags", ] +[[package]] +name = "redox_syscall" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "567664f262709473930a4bf9e51bf2ebf3348f2e748ccc50dea20646858f8f29" +dependencies = [ + "bitflags", +] + [[package]] name = "redox_users" version = "0.4.3" @@ -1818,15 +1938,15 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b033d837a7cf162d7993aded9304e30a83213c648b6e389db233191f891e5c2b" dependencies = [ "getrandom", - "redox_syscall", + "redox_syscall 0.2.16", "thiserror", ] [[package]] name = "regex" -version = "1.7.1" +version = "1.7.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "48aaa5748ba571fb95cd2c85c09f629215d3a6ece942baa100950af03a34f733" +checksum = "8b1f693b24f6ac912f4893ef08244d70b6067480d2f1a46e950c9691e6749d1d" dependencies = [ "aho-corasick", "memchr", @@ -1844,15 +1964,15 @@ dependencies = [ [[package]] name = "regex-syntax" -version = "0.6.28" +version = "0.6.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "456c603be3e8d448b072f410900c09faf164fbce2d480456f50eea6e25f9c848" +checksum = "f162c6dd7b008981e4d40210aca20b4bd0f9b60ca9271061b07f78537722f2e1" [[package]] name = "reqwest" -version = "0.11.14" +version = "0.11.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "21eed90ec8570952d53b772ecf8f206aa1ec9a3d76b2521c56c42973f2d91ee9" +checksum = "27b71749df584b7f4cac2c426c127a7c785a5106cc98f7a8feb044115f0fa254" dependencies = [ "base64 0.21.0", "bytes", @@ -1887,9 +2007,9 @@ dependencies = [ [[package]] name = "rust-embed" -version = "6.4.2" +version = "6.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "283ffe2f866869428c92e0d61c2f35dfb4355293cdfdc48f49e895c15f1333d1" +checksum = "1b68543d5527e158213414a92832d2aab11a84d2571a5eb021ebe22c43aab066" dependencies = [ "rust-embed-impl", "rust-embed-utils", @@ -1898,23 +2018,23 @@ dependencies = [ [[package]] name = "rust-embed-impl" -version = "6.3.1" +version = "6.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "31ab23d42d71fb9be1b643fe6765d292c5e14d46912d13f3ae2815ca048ea04d" +checksum = "4d4e0f0ced47ded9a68374ac145edd65a6c1fa13a96447b873660b2a568a0fd7" dependencies = [ "proc-macro2", "quote", "rust-embed-utils", "shellexpand", - "syn", + "syn 1.0.109", "walkdir", ] [[package]] name = "rust-embed-utils" -version = "7.3.0" +version = "7.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c1669d81dfabd1b5f8e2856b8bbe146c6192b0ba22162edc738ac0a5de18f054" +checksum = "512b0ab6853f7e14e3c8754acb43d6f748bb9ced66aa5915a6553ac8213f7731" dependencies = [ "sha2", "walkdir", @@ -1922,9 +2042,9 @@ dependencies = [ [[package]] name = "rustix" -version = "0.36.8" +version = "0.37.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f43abb88211988493c1abb44a70efa56ff0ce98f233b7b276146f1f3f7ba9644" +checksum = "0e78cc525325c06b4a7ff02db283472f3c042b7ff0c391f96c6d5ac6f4f91b75" dependencies = [ "bitflags", "errno", @@ -1936,15 +2056,15 @@ dependencies = [ [[package]] name = "rustversion" -version = "1.0.11" +version = "1.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5583e89e108996506031660fe09baa5011b9dd0341b89029313006d1fb508d70" +checksum = "4f3208ce4d8448b3f3e7d168a73f5e0c43a61e32930de3bceeccedb388b6bf06" [[package]] name = "ryu" -version = "1.0.12" +version = "1.0.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7b4b9743ed687d4b4bcedf9ff5eaa7398495ae14e61cba0a295704edbc7decde" +checksum = "f91339c0467de62360649f8d3e185ca8de4224ff281f66000de5eb2a77a79041" [[package]] name = "same-file" @@ -1995,29 +2115,29 @@ dependencies = [ [[package]] name = "serde" -version = "1.0.152" +version = "1.0.159" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bb7d1f0d3021d347a83e556fc4683dea2ea09d87bccdf88ff5c12545d89d5efb" +checksum = "3c04e8343c3daeec41f58990b9d77068df31209f2af111e059e9fe9646693065" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.152" +version = "1.0.159" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "af487d118eecd09402d70a5d72551860e788df87b464af30e5ea6a38c75c541e" +checksum = "4c614d17805b093df4b147b51339e7e44bf05ef59fba1e45d83500bcfb4d8585" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.12", ] [[package]] name = "serde_json" -version = "1.0.93" +version = "1.0.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cad406b69c91885b5107daf2c29572f6c8cdb3c66826821e286c533490c0bc76" +checksum = "d721eca97ac802aa7777b701877c8004d950fc142651367300d21c1cc0194744" dependencies = [ "itoa", "ryu", @@ -2026,9 +2146,9 @@ dependencies = [ [[package]] name = "serde_path_to_error" -version = "0.1.9" +version = "0.1.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "26b04f22b563c91331a10074bda3dd5492e3cc39d56bd557e91c0af42b6c7341" +checksum = "f7f05c1d5476066defcdfacce1f52fc3cae3af1d3089727100c02ae92e5abbe0" dependencies = [ "serde", ] @@ -2045,6 +2165,17 @@ dependencies = [ "serde", ] +[[package]] +name = "sha1" +version = "0.10.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f04293dc80c3993519f2d7f6f511707ee7094fe0c6d3406feb330cdb3540eba3" +dependencies = [ + "cfg-if", + "cpufeatures", + "digest", +] + [[package]] name = "sha2" version = "0.10.6" @@ -2071,7 +2202,7 @@ version = "2.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7ccc8076840c4da029af4f87e4e8daeb0fca6b87bbb02e10cb60b791450e11e4" dependencies = [ - "dirs 4.0.0", + "dirs", ] [[package]] @@ -2106,9 +2237,9 @@ checksum = "a507befe795404456341dfab10cef66ead4c041f62b8b11bbb92bffe5d0953e0" [[package]] name = "socket2" -version = "0.4.8" +version = "0.4.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "95a21dcece9b5991cfd1ece74654c8e3d0d5aab499d359b0395e38229c0bb5a3" +checksum = "64a4a911eed85daf18834cfaa86a79b7d266ff93ff5ba14005426219480ed662" dependencies = [ "libc", "winapi", @@ -2132,18 +2263,6 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" -[[package]] -name = "strsim" -version = "0.8.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8ea5119cdb4c55b55d432abb513a0429384878c15dde60cc77b1c99de1a95a6a" - -[[package]] -name = "strsim" -version = "0.9.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6446ced80d6c486436db5c078dde11a9f73d42b57fb273121e160b84f63d894c" - [[package]] name = "strsim" version = "0.10.0" @@ -2160,6 +2279,12 @@ dependencies = [ "winapi", ] +[[package]] +name = "subtle" +version = "2.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6bdef32e8150c2a081110b42772ffe7d7c9032b606bc226c8260fd97e0976601" + [[package]] name = "syn" version = "1.0.109" @@ -2171,6 +2296,17 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "syn" +version = "2.0.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "79d9531f94112cfc3e4c8f5f02cb2b58f72c97b7efd85f70203cc6d8efda5927" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + [[package]] name = "sync_wrapper" version = "0.1.2" @@ -2190,24 +2326,15 @@ dependencies = [ [[package]] name = "tempfile" -version = "3.4.0" +version = "3.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "af18f7ae1acd354b992402e9ec5864359d693cd8a79dcbef59f76891701c1e95" +checksum = "b9fbec84f381d5795b08656e4912bec604d162bff9291d6189a78f4c8ab87998" dependencies = [ "cfg-if", "fastrand", - "redox_syscall", + "redox_syscall 0.3.5", "rustix", - "windows-sys 0.42.0", -] - -[[package]] -name = "termcolor" -version = "1.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "be55cf8942feac5c765c2c993422806843c9a9a45d4d5c407ad6dd2ea95eb9b6" -dependencies = [ - "winapi-util", + "windows-sys 0.45.0", ] [[package]] @@ -2231,7 +2358,7 @@ dependencies = [ name = "text-generation-launcher" version = "0.4.3" dependencies = [ - "clap 4.1.8", + "clap", "ctrlc", "float_eq", "reqwest", @@ -2249,7 +2376,7 @@ dependencies = [ "async-stream", "axum", "axum-tracing-opentelemetry", - "clap 4.1.8", + "clap", "futures", "metrics", "metrics-exporter-prometheus", @@ -2266,7 +2393,7 @@ dependencies = [ "tokenizers", "tokio", "tokio-stream", - "tower-http 0.3.5", + "tower-http", "tracing", "tracing-opentelemetry", "tracing-subscriber", @@ -2274,33 +2401,24 @@ dependencies = [ "utoipa-swagger-ui", ] -[[package]] -name = "textwrap" -version = "0.11.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d326610f408c7a4eb6f51c37c330e496b08506c9457c9d34287ecc38809fb060" -dependencies = [ - "unicode-width", -] - [[package]] name = "thiserror" -version = "1.0.38" +version = "1.0.40" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6a9cd18aa97d5c45c6603caea1da6628790b37f7a34b6ca89522331c5180fed0" +checksum = "978c9a314bd8dc99be594bc3c175faaa9794be04a5a5e153caba6915336cebac" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.38" +version = "1.0.40" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1fb327af4685e4d03fa8cbcf1716380da910eeb2bb8be417e7f9fd3fb164f36f" +checksum = "f9456a42c5b0d803c8cd86e73dd7cc9edd429499f37a3550d286d5e86720569f" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.12", ] [[package]] @@ -2315,14 +2433,20 @@ dependencies = [ [[package]] name = "time" -version = "0.1.43" +version = "0.3.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ca8a50ef2360fbd1eeb0ecd46795a87a19024eb4b53c5dc916ca1fd95fe62438" +checksum = "cd0cbfecb4d19b5ea75bb31ad904eb5b9fa13f21079c3b92017ebdf4999a5890" dependencies = [ - "libc", - "winapi", + "serde", + "time-core", ] +[[package]] +name = "time-core" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2e153e1f1acaef8acc537e68b44906d2db6436e2b35ac2c6b42640fff91f00fd" + [[package]] name = "tinyvec" version = "1.6.0" @@ -2341,14 +2465,13 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tokenizers" version = "0.13.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f4ff2dd291eac98dcea13e8cf7a0b28c373a90dc9210ccdab0fa9e69ee0cac69" +source = "git+https://github.com/huggingface/tokenizers.git#3aaf4946b3c82e7a04db6ecde7c8bb4e474e54af" dependencies = [ "aho-corasick", "cached-path", - "clap 2.34.0", + "clap", "derive_builder", - "dirs 3.0.2", + "dirs", "esaxx-rs", "getrandom", "indicatif 0.15.0", @@ -2356,6 +2479,7 @@ dependencies = [ "lazy_static", "log", "macro_rules_attribute", + "monostate", "onig", "paste", "rand", @@ -2375,14 +2499,13 @@ dependencies = [ [[package]] name = "tokio" -version = "1.26.0" +version = "1.27.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "03201d01c3c27a29c8a5cee5b55a93ddae1ccf6f08f65365c2c918f8c1b76f64" +checksum = "d0de47a4eecbe11f498978a9b29d792f0d2692d1dd003650c24c76510e3bc001" dependencies = [ "autocfg", "bytes", "libc", - "memchr", "mio", "num_cpus", "parking_lot", @@ -2405,13 +2528,13 @@ dependencies = [ [[package]] name = "tokio-macros" -version = "1.8.2" +version = "2.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d266c00fde287f55d3f1c3e96c500c362a2b8c695076ec180f27918820bc6df8" +checksum = "61a573bdc87985e9d6ddeed1b3d864e8a302c847e40d647746df2f1de209d1ce" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.12", ] [[package]] @@ -2491,7 +2614,7 @@ dependencies = [ "proc-macro2", "prost-build", "quote", - "syn", + "syn 1.0.109", ] [[package]] @@ -2533,25 +2656,6 @@ dependencies = [ "tracing", ] -[[package]] -name = "tower-http" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5d1d42a9b3f3ec46ba828e8d376aec14592ea199f70a06a548587ecd1c4ab658" -dependencies = [ - "bitflags", - "bytes", - "futures-core", - "futures-util", - "http", - "http-body", - "http-range-header", - "pin-project-lite", - "tower", - "tower-layer", - "tower-service", -] - [[package]] name = "tower-layer" version = "0.3.2" @@ -2585,7 +2689,7 @@ checksum = "4017f8f45139870ca7e672686113917c71c7a6e02d4924eda67186083c03081a" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 1.0.109", ] [[package]] @@ -2697,15 +2801,15 @@ dependencies = [ [[package]] name = "unicode-bidi" -version = "0.3.10" +version = "0.3.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d54675592c1dbefd78cbd98db9bacd89886e1ca50692a0692baefffdeb92dd58" +checksum = "92888ba5573ff080736b3648696b70cafad7d250551175acbaa4e0385b3e1460" [[package]] name = "unicode-ident" -version = "1.0.6" +version = "1.0.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "84a22b9f218b40614adcb3f4ff08b703773ad44fa9423e4e0d346d5db86e4ebc" +checksum = "e5464a87b239f13a63a501f2701565754bae92d243d4bb7eb12f6d57d2269bf4" [[package]] name = "unicode-normalization" @@ -2755,10 +2859,16 @@ dependencies = [ ] [[package]] -name = "utoipa" -version = "3.0.3" +name = "utf8parse" +version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a15f6da6a2b471134ca44b7d18e8a76d73035cf8b3ed24c4dd5ca6a63aa439c5" +checksum = "711b9620af191e0cdc7468a8d14e709c3dcdb115b36f838e601583af800a370a" + +[[package]] +name = "utoipa" +version = "3.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "24e7ee17c9ef094b86e1e04170d90765bd76cb381921dacb4d3e175a267bdae6" dependencies = [ "indexmap", "serde", @@ -2768,21 +2878,21 @@ dependencies = [ [[package]] name = "utoipa-gen" -version = "3.0.3" +version = "3.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6f2e33027986a4707b3f5c37ed01b33d0e5a53da30204b52ff18f80600f1d0ec" +checksum = "df6f458e5abc811d44aca28455efc4163fb7565a7af2aa32d17611f3d1d9794d" dependencies = [ "proc-macro-error", "proc-macro2", "quote", - "syn", + "syn 2.0.12", ] [[package]] name = "utoipa-swagger-ui" -version = "3.0.2" +version = "3.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ae3d4f4da6408f0f20ff58196ed619c94306ab32635aeca3d3fa0768c0bd0de2" +checksum = "40a6d6921f30f7a9d023a2523cbff2e94a1a4b535abbdda6e6d0c69b61f19f05" dependencies = [ "axum", "mime_guess", @@ -2791,7 +2901,7 @@ dependencies = [ "serde", "serde_json", "utoipa", - "zip 0.6.4", + "zip", ] [[package]] @@ -2806,12 +2916,6 @@ version = "0.2.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" -[[package]] -name = "vec_map" -version = "0.8.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f1bddf1187be692e79c5ffeab891132dfb0f236ed36a43c7ed39f1165ee20191" - [[package]] name = "version_check" version = "0.9.4" @@ -2820,12 +2924,11 @@ checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" [[package]] name = "walkdir" -version = "2.3.2" +version = "2.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "808cf2735cd4b6866113f648b791c6adc5714537bc222d9347bb203386ffda56" +checksum = "36df944cda56c7d8d8b7496af378e6b16de9284591917d307c9b4d313c44e698" dependencies = [ "same-file", - "winapi", "winapi-util", ] @@ -2872,7 +2975,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn", + "syn 1.0.109", "wasm-bindgen-shared", ] @@ -2906,7 +3009,7 @@ checksum = "2aff81306fcac3c7515ad4e177f521b5c9a15f2b08f4e32d823066102f35a5f6" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 1.0.109", "wasm-bindgen-backend", "wasm-bindgen-shared", ] @@ -2995,9 +3098,9 @@ dependencies = [ [[package]] name = "windows-targets" -version = "0.42.1" +version = "0.42.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8e2522491fbfcd58cc84d47aeb2958948c4b8982e9a2d8a2a35bbaed431390e7" +checksum = "8e5180c00cd44c9b1c88adb3693291f1cd93605ded80c250a75d472756b4d071" dependencies = [ "windows_aarch64_gnullvm", "windows_aarch64_msvc", @@ -3010,45 +3113,45 @@ dependencies = [ [[package]] name = "windows_aarch64_gnullvm" -version = "0.42.1" +version = "0.42.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8c9864e83243fdec7fc9c5444389dcbbfd258f745e7853198f365e3c4968a608" +checksum = "597a5118570b68bc08d8d59125332c54f1ba9d9adeedeef5b99b02ba2b0698f8" [[package]] name = "windows_aarch64_msvc" -version = "0.42.1" +version = "0.42.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4c8b1b673ffc16c47a9ff48570a9d85e25d265735c503681332589af6253c6c7" +checksum = "e08e8864a60f06ef0d0ff4ba04124db8b0fb3be5776a5cd47641e942e58c4d43" [[package]] name = "windows_i686_gnu" -version = "0.42.1" +version = "0.42.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "de3887528ad530ba7bdbb1faa8275ec7a1155a45ffa57c37993960277145d640" +checksum = "c61d927d8da41da96a81f029489353e68739737d3beca43145c8afec9a31a84f" [[package]] name = "windows_i686_msvc" -version = "0.42.1" +version = "0.42.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bf4d1122317eddd6ff351aa852118a2418ad4214e6613a50e0191f7004372605" +checksum = "44d840b6ec649f480a41c8d80f9c65108b92d89345dd94027bfe06ac444d1060" [[package]] name = "windows_x86_64_gnu" -version = "0.42.1" +version = "0.42.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c1040f221285e17ebccbc2591ffdc2d44ee1f9186324dd3e84e99ac68d699c45" +checksum = "8de912b8b8feb55c064867cf047dda097f92d51efad5b491dfb98f6bbb70cb36" [[package]] name = "windows_x86_64_gnullvm" -version = "0.42.1" +version = "0.42.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "628bfdf232daa22b0d64fdb62b09fcc36bb01f05a3939e20ab73aaf9470d0463" +checksum = "26d41b46a36d453748aedef1486d5c7a85db22e56aff34643984ea85514e94a3" [[package]] name = "windows_x86_64_msvc" -version = "0.42.1" +version = "0.42.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "447660ad36a13288b1db4d4248e857b510e8c3a225c822ba4fb748c0aafecffd" +checksum = "9aec5da331524158c6d1a4ac0ab1541149c0b9505fde06423b02f5ef0106b9f0" [[package]] name = "winreg" @@ -3068,37 +3171,52 @@ dependencies = [ "libc", ] -[[package]] -name = "zip" -version = "0.5.13" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "93ab48844d61251bb3835145c521d88aa4031d7139e8485990f60ca911fa0815" -dependencies = [ - "byteorder", - "bzip2", - "crc32fast", - "flate2", - "thiserror", - "time", -] - [[package]] name = "zip" version = "0.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0445d0fbc924bb93539b4316c11afb121ea39296f99a3c4c9edad09e3658cdef" dependencies = [ + "aes", "byteorder", + "bzip2", + "constant_time_eq", "crc32fast", "crossbeam-utils", "flate2", + "hmac", + "pbkdf2", + "sha1", + "time", + "zstd", ] [[package]] -name = "zip-extensions" -version = "0.6.1" +name = "zstd" +version = "0.11.2+zstd.1.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a64c3c977bc3434ce2d4bcea8ad3c644672de0f2c402b72b9171ca80a8885d14" +checksum = "20cc960326ece64f010d2d2107537f26dc589a6573a316bd5b1dba685fa5fde4" dependencies = [ - "zip 0.5.13", + "zstd-safe", +] + +[[package]] +name = "zstd-safe" +version = "5.0.2+zstd.1.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d2a5585e04f9eea4b2a3d1eca508c4dee9592a89ef6f450c11719da0726f4db" +dependencies = [ + "libc", + "zstd-sys", +] + +[[package]] +name = "zstd-sys" +version = "2.0.7+zstd.1.5.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94509c3ba2fe55294d752b79842c530ccfab760192521df74a081a78d2b3c7f5" +dependencies = [ + "cc", + "libc", + "pkg-config", ] diff --git a/router/src/main.rs b/router/src/main.rs index 81c6aeef..bad3df93 100644 --- a/router/src/main.rs +++ b/router/src/main.rs @@ -37,7 +37,7 @@ struct Args { max_waiting_tokens: usize, #[clap(default_value = "3000", long, short, env)] port: u16, - #[clap(default_value = "/tmp/text-generation-0", long, env)] + #[clap(default_value = "/tmp/text-generation-server-0", long, env)] master_shard_uds_path: String, #[clap(default_value = "bigscience/bloom", long, env)] tokenizer_name: String, diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index 5e5610ff..ee977948 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -21,18 +21,21 @@ import torch import torch.distributed -from torch.nn import functional as F - from torch import nn from transformers.activations import ACT2FN +from text_generation_server.models.custom_modeling.tensor_parallel import ( + TensorParallelEmbedding, + TensorParallelRowLinear, + TensorParallelColumnLinear, +) +from text_generation_server.models.custom_modeling.linear import FastLinear +from text_generation_server.models.custom_modeling.rotary import PositionRotaryEmbedding + # Flash attention imports -import rotary_emb import flash_attn_cuda import dropout_layer_norm -from flash_attn.layers.rotary import RotaryEmbedding - class LlamaRMSNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): @@ -84,182 +87,6 @@ class LlamaRMSNorm(nn.Module): return normed_hidden_states, res -class FastLinear(nn.Linear): - def __init__( - self, - in_features: int, - out_features: int, - bias: bool = True, - device=None, - dtype=None, - ) -> None: - super(FastLinear, self).__init__(in_features, out_features, bias, device, dtype) - - def transpose_weight(self): - self.weight = nn.Parameter(self.weight.T) - - def forward(self, input: torch.Tensor) -> torch.Tensor: - if self.bias is not None: - return torch.addmm(self.bias, input, self.weight) - return torch.matmul(input, self.weight) - - -class TensorParallelColumnLinear(FastLinear): - def __init__( - self, - in_features, - out_features, - process_group: torch.distributed.ProcessGroup, - bias=True, - device=None, - dtype=None, - ): - self.process_group = process_group - self.tp_world_size = process_group.size() - assert out_features % self.tp_world_size == 0 - out_features = out_features // self.tp_world_size - - super().__init__( - in_features=in_features, - out_features=out_features, - bias=bias, - device=device, - dtype=dtype, - ) - - -class TensorParallelRowLinear(FastLinear): - def __init__( - self, - in_features, - out_features, - process_group: torch.distributed.ProcessGroup, - reduce=True, - bias=True, - device=None, - dtype=None, - ): - self.process_group = process_group - self.tp_world_size = process_group.size() - self.reduce = reduce - assert in_features % self.tp_world_size == 0 - in_features = in_features // self.tp_world_size - - super().__init__( - in_features=in_features, - out_features=out_features, - bias=bias, - device=device, - dtype=dtype, - ) - - def forward(self, input: torch.Tensor) -> torch.Tensor: - out = super(TensorParallelRowLinear, self).forward(input) - if self.reduce: - torch.distributed.all_reduce(out, group=self.process_group) - - return out - - -class TensorParallelEmbedding(nn.Embedding): - def __init__( - self, - num_embeddings, - embedding_dim, - process_group: torch.distributed.ProcessGroup, - padding_idx=None, - max_norm=None, - norm_type=2.0, - scale_grad_by_freq=False, - sparse=False, - _weight=None, - device=None, - dtype=None, - ): - self.process_group = process_group - self.tp_rank = process_group.rank() - self.tp_world_size = process_group.size() - - self.original_num_embeddings = num_embeddings - - assert num_embeddings % self.tp_world_size == 0 - block_size = num_embeddings // self.tp_world_size - # inputs in `[min_id, max_id[` are handled by `self` to get embeddings - self.min_id = self.tp_rank * block_size - self.max_id = (self.tp_rank + 1) * block_size - - # Additional entry that will map to zero - # Used for masking - self.null_idx = block_size - - super().__init__( - block_size, - embedding_dim, - padding_idx=padding_idx, - max_norm=max_norm, - norm_type=norm_type, - scale_grad_by_freq=scale_grad_by_freq, - sparse=sparse, - _weight=_weight, - device=device, - dtype=dtype, - ) - - def add_null_idx(self): - """Additional 0 entry used for masking""" - self.weight = nn.Parameter(F.pad(self.weight, (0, 0, 0, 1))) - - def forward(self, input: torch.Tensor) -> torch.Tensor: - # default all out of bounds values to `self.null_idx` that will then be mapped to 0 - # translate for [0, self.max_id - self.min_id[ - input = torch.where( - (self.min_id > input) | (input >= self.max_id), - self.null_idx, - input - self.min_id, - ) - out = super().forward(input) - torch.distributed.all_reduce(out, group=self.process_group) - return out - - -class PositionRotaryEmbedding(RotaryEmbedding): - def _update_cos_sin_cache(self, dtype, device, seqlen): - # Reset the tables if the sequence length has changed, - # or if we're on a new device (possibly due to tracing for instance) - if ( - seqlen > self._seq_len_cached - or self._cos_cached.device != device - or self._cos_cached.dtype != dtype - ): - self._seq_len_cached = seqlen - t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype) - freqs = torch.outer(t, self.inv_freq.to(device=t.device)) - self._cos_cached = torch.cos(freqs).to(dtype) - self._sin_cached = torch.sin(freqs).to(dtype) - - def get_cos_sin(self, position_ids: torch.Tensor, max_s: int, dtype: torch.dtype): - """ - Return cos and sin for the asked position ids - """ - - self._update_cos_sin_cache(dtype, position_ids.device, max_s) - - cos = torch.index_select(self._cos_cached, 0, position_ids) - sin = torch.index_select(self._sin_cached, 0, position_ids) - return cos.unsqueeze(1), sin.unsqueeze(1) - - def forward(self, qkv: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor): - rotary_dim = cos.shape[-1] - q1 = qkv[:, 0, :, :rotary_dim] - q2 = qkv[:, 0, :, rotary_dim : 2 * rotary_dim] - k1 = qkv[:, 1, :, :rotary_dim] - k2 = qkv[:, 1, :, rotary_dim : 2 * rotary_dim] - - rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, False) - rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False) - return qkv - - class FlashLlamaAttention(torch.nn.Module): def __init__( self, @@ -314,12 +141,12 @@ class FlashLlamaAttention(torch.nn.Module): layer_past[...] = qkv_rot[:, 1:] # output - attn_output = torch.empty_like(qkv[:, 0]) + attn_output = torch.empty_like(qkv_rot[:, 0]) # flash attention flash_attn_cuda.fwd( - qkv[:, 0], - qkv[:, 1], - qkv[:, 2], + qkv_rot[:, 0], + qkv_rot[:, 1], + qkv_rot[:, 2], attn_output, cu_seqlens, cu_seqlens, @@ -369,7 +196,12 @@ class LlamaMLP(nn.Module): self.act = ( ACT2FN[act] if "gelu" not in act - else lambda x: torch.nn.functional.gelu(x, approximate="tanh") + else lambda x: torch.nn.functional.gelu( + x, + approximate="tanh" + if act in ["gelu_fast", "gelu_pytorch_tanh"] + else None, + ) ) if process_group is None: diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py index f3517c47..93f1b0ca 100644 --- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py @@ -21,20 +21,23 @@ import torch import torch.distributed -from torch.nn import functional as F - from torch import nn from transformers.activations import ACT2FN from transformers.modeling_utils import PreTrainedModel from transformers.models.gpt_neox import GPTNeoXConfig +from text_generation_server.models.custom_modeling.tensor_parallel import ( + TensorParallelEmbedding, + TensorParallelRowLinear, + TensorParallelColumnLinear, +) +from text_generation_server.models.custom_modeling.linear import FastLinear +from text_generation_server.models.custom_modeling.rotary import PositionRotaryEmbedding + # Flash attention imports -import rotary_emb import flash_attn_cuda import dropout_layer_norm -from flash_attn.layers.rotary import RotaryEmbedding - class FastLayerNorm(nn.LayerNorm): def forward(self, hidden_states, residual=None): @@ -72,184 +75,6 @@ class FastLayerNorm(nn.LayerNorm): return normed_hidden_states, residual -class FastLinear(nn.Linear): - def __init__( - self, - in_features: int, - out_features: int, - bias: bool = True, - device=None, - dtype=None, - ) -> None: - super(FastLinear, self).__init__(in_features, out_features, bias, device, dtype) - - def transpose_weight(self): - self.weight = nn.Parameter(self.weight.T) - - def forward(self, input: torch.Tensor) -> torch.Tensor: - if self.bias is not None: - return torch.addmm(self.bias, input, self.weight) - return torch.matmul(input, self.weight) - - -class TensorParallelColumnLinear(FastLinear): - def __init__( - self, - in_features, - out_features, - process_group: torch.distributed.ProcessGroup, - bias=True, - device=None, - dtype=None, - ): - self.process_group = process_group - self.tp_world_size = process_group.size() - assert out_features % self.tp_world_size == 0 - out_features = out_features // self.tp_world_size - - super().__init__( - in_features=in_features, - out_features=out_features, - bias=bias, - device=device, - dtype=dtype, - ) - - -class TensorParallelRowLinear(FastLinear): - def __init__( - self, - in_features, - out_features, - process_group: torch.distributed.ProcessGroup, - reduce=True, - bias=True, - device=None, - dtype=None, - ): - self.process_group = process_group - self.tp_world_size = process_group.size() - self.reduce = reduce - assert in_features % self.tp_world_size == 0 - in_features = in_features // self.tp_world_size - - super().__init__( - in_features=in_features, - out_features=out_features, - bias=bias, - device=device, - dtype=dtype, - ) - - def forward(self, input: torch.Tensor) -> torch.Tensor: - out = super(TensorParallelRowLinear, self).forward(input) - if self.reduce: - torch.distributed.all_reduce(out, group=self.process_group) - - return out - - -class TensorParallelEmbedding(nn.Embedding): - def __init__( - self, - num_embeddings, - embedding_dim, - process_group: torch.distributed.ProcessGroup, - padding_idx=None, - max_norm=None, - norm_type=2.0, - scale_grad_by_freq=False, - sparse=False, - _weight=None, - device=None, - dtype=None, - ): - self.process_group = process_group - self.tp_rank = process_group.rank() - self.tp_world_size = process_group.size() - - self.original_num_embeddings = num_embeddings - - assert num_embeddings % self.tp_world_size == 0 - block_size = num_embeddings // self.tp_world_size - # inputs in `[min_id, max_id[` are handled by `self` to get embeddings - self.min_id = self.tp_rank * block_size - self.max_id = (self.tp_rank + 1) * block_size - - # Additional entry that will map to zero - # Used for masking - self.null_idx = block_size - - super().__init__( - block_size, - embedding_dim, - padding_idx=padding_idx, - max_norm=max_norm, - norm_type=norm_type, - scale_grad_by_freq=scale_grad_by_freq, - sparse=sparse, - _weight=_weight, - device=device, - dtype=dtype, - ) - - def add_null_idx(self): - """Additional 0 entry used for masking""" - self.weight = nn.Parameter(F.pad(self.weight, (0, 0, 0, 1))) - - def forward(self, input: torch.Tensor) -> torch.Tensor: - # default all out of bounds values to `self.null_idx` that will then be mapped to 0 - # translate for [0, self.max_id - self.min_id[ - input = torch.where( - (self.min_id > input) | (input >= self.max_id), - self.null_idx, - input - self.min_id, - ) - out = super().forward(input) - torch.distributed.all_reduce(out, group=self.process_group) - return out - - -class PositionRotaryEmbedding(RotaryEmbedding): - def _update_cos_sin_cache(self, dtype, device, seqlen): - # Reset the tables if the sequence length has changed, - # or if we're on a new device (possibly due to tracing for instance) - if ( - seqlen > self._seq_len_cached - or self._cos_cached.device != device - or self._cos_cached.dtype != dtype - ): - self._seq_len_cached = seqlen - t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype) - # Don't do einsum, it converts fp32 to fp16 - # freqs = torch.einsum("i,j->ij", t, self.inv_freq) - freqs = torch.outer(t, self.inv_freq.to(device=t.device)) - self._cos_cached = torch.cos(freqs).to(dtype) - self._sin_cached = torch.sin(freqs).to(dtype) - - def get_cos_sin(self, position_ids: torch.Tensor, max_s: int, dtype: torch.dtype): - """ - Return cos and sin for the asked position ids - """ - - self._update_cos_sin_cache(dtype, position_ids.device, max_s) - - cos = torch.index_select(self._cos_cached, 0, position_ids) - sin = torch.index_select(self._sin_cached, 0, position_ids) - return cos.unsqueeze(1), sin.unsqueeze(1) - - def forward(self, qkv: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor): - rotary_dim = cos.shape[-1] - q1 = qkv[:, 0, :, :rotary_dim] - q2 = qkv[:, 0, :, rotary_dim : 2 * rotary_dim] - k1 = qkv[:, 1, :, :rotary_dim] - k2 = qkv[:, 1, :, rotary_dim : 2 * rotary_dim] - - rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, False) - rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False) - return qkv - - class FlashNeoxAttention(torch.nn.Module): def __init__( self, @@ -376,7 +201,12 @@ class FlashMLP(nn.Module): self.act = ( ACT2FN[act] if "gelu" not in act - else lambda x: torch.nn.functional.gelu(x, approximate="tanh") + else lambda x: torch.nn.functional.gelu( + x, + approximate="tanh" + if act in ["gelu_fast", "gelu_pytorch_tanh"] + else None, + ) ) if process_group is None: diff --git a/server/text_generation_server/models/custom_modeling/linear.py b/server/text_generation_server/models/custom_modeling/linear.py new file mode 100644 index 00000000..1ca8e3a9 --- /dev/null +++ b/server/text_generation_server/models/custom_modeling/linear.py @@ -0,0 +1,22 @@ +import torch +from torch import nn + + +class FastLinear(nn.Linear): + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + device=None, + dtype=None, + ) -> None: + super(FastLinear, self).__init__(in_features, out_features, bias, device, dtype) + + def transpose_weight(self): + self.weight = nn.Parameter(self.weight.T) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + if self.bias is not None: + return torch.addmm(self.bias, input, self.weight) + return torch.matmul(input, self.weight) diff --git a/server/text_generation_server/models/custom_modeling/rotary.py b/server/text_generation_server/models/custom_modeling/rotary.py new file mode 100644 index 00000000..69f97558 --- /dev/null +++ b/server/text_generation_server/models/custom_modeling/rotary.py @@ -0,0 +1,42 @@ +import torch +import rotary_emb + +from flash_attn.layers.rotary import RotaryEmbedding + + +class PositionRotaryEmbedding(RotaryEmbedding): + def _update_cos_sin_cache(self, dtype, device, seqlen): + # Reset the tables if the sequence length has changed, + # or if we're on a new device (possibly due to tracing for instance) + if ( + seqlen > self._seq_len_cached + or self._cos_cached.device != device + or self._cos_cached.dtype != dtype + ): + self._seq_len_cached = seqlen + t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device=t.device)) + self._cos_cached = torch.cos(freqs).to(dtype) + self._sin_cached = torch.sin(freqs).to(dtype) + + def get_cos_sin(self, position_ids: torch.Tensor, max_s: int, dtype: torch.dtype): + """ + Return cos and sin for the asked position ids + """ + + self._update_cos_sin_cache(dtype, position_ids.device, max_s) + + cos = torch.index_select(self._cos_cached, 0, position_ids) + sin = torch.index_select(self._sin_cached, 0, position_ids) + return cos.unsqueeze(1), sin.unsqueeze(1) + + def forward(self, qkv: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor): + rotary_dim = cos.shape[-1] + q1 = qkv[:, 0, :, :rotary_dim] + q2 = qkv[:, 0, :, rotary_dim : 2 * rotary_dim] + k1 = qkv[:, 1, :, :rotary_dim] + k2 = qkv[:, 1, :, rotary_dim : 2 * rotary_dim] + + rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, False) + rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False) + return qkv diff --git a/server/text_generation_server/models/custom_modeling/tensor_parallel.py b/server/text_generation_server/models/custom_modeling/tensor_parallel.py new file mode 100644 index 00000000..e01b8835 --- /dev/null +++ b/server/text_generation_server/models/custom_modeling/tensor_parallel.py @@ -0,0 +1,124 @@ +import torch +import torch.distributed +from torch import nn +from torch.nn import functional as F + +from text_generation_server.models.custom_modeling.linear import FastLinear + + +class TensorParallelColumnLinear(FastLinear): + def __init__( + self, + in_features, + out_features, + process_group: torch.distributed.ProcessGroup, + bias=True, + device=None, + dtype=None, + ): + self.process_group = process_group + self.tp_world_size = process_group.size() + assert out_features % self.tp_world_size == 0 + out_features = out_features // self.tp_world_size + + super().__init__( + in_features=in_features, + out_features=out_features, + bias=bias, + device=device, + dtype=dtype, + ) + + +class TensorParallelRowLinear(FastLinear): + def __init__( + self, + in_features, + out_features, + process_group: torch.distributed.ProcessGroup, + reduce=True, + bias=True, + device=None, + dtype=None, + ): + self.process_group = process_group + self.tp_world_size = process_group.size() + self.reduce = reduce + assert in_features % self.tp_world_size == 0 + in_features = in_features // self.tp_world_size + + super().__init__( + in_features=in_features, + out_features=out_features, + bias=bias, + device=device, + dtype=dtype, + ) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + out = super(TensorParallelRowLinear, self).forward(input) + if self.reduce: + torch.distributed.all_reduce(out, group=self.process_group) + + return out + + +class TensorParallelEmbedding(nn.Embedding): + def __init__( + self, + num_embeddings, + embedding_dim, + process_group: torch.distributed.ProcessGroup, + padding_idx=None, + max_norm=None, + norm_type=2.0, + scale_grad_by_freq=False, + sparse=False, + _weight=None, + device=None, + dtype=None, + ): + self.process_group = process_group + self.tp_rank = process_group.rank() + self.tp_world_size = process_group.size() + + self.original_num_embeddings = num_embeddings + + assert num_embeddings % self.tp_world_size == 0 + block_size = num_embeddings // self.tp_world_size + # inputs in `[min_id, max_id[` are handled by `self` to get embeddings + self.min_id = self.tp_rank * block_size + self.max_id = (self.tp_rank + 1) * block_size + + # Additional entry that will map to zero + # Used for masking + self.null_idx = block_size + + super().__init__( + block_size, + embedding_dim, + padding_idx=padding_idx, + max_norm=max_norm, + norm_type=norm_type, + scale_grad_by_freq=scale_grad_by_freq, + sparse=sparse, + _weight=_weight, + device=device, + dtype=dtype, + ) + + def add_null_idx(self): + """Additional 0 entry used for masking""" + self.weight = nn.Parameter(F.pad(self.weight, (0, 0, 0, 1))) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + # default all out of bounds values to `self.null_idx` that will then be mapped to 0 + # translate for [0, self.max_id - self.min_id[ + input = torch.where( + (self.min_id > input) | (input >= self.max_id), + self.null_idx, + input - self.min_id, + ) + out = super().forward(input) + torch.distributed.all_reduce(out, group=self.process_group) + return out diff --git a/server/text_generation_server/models/flash_llama.py b/server/text_generation_server/models/flash_llama.py index 3029ab89..7d7e2cf5 100644 --- a/server/text_generation_server/models/flash_llama.py +++ b/server/text_generation_server/models/flash_llama.py @@ -11,6 +11,8 @@ from typing import Optional, Tuple, List from text_generation_server.models import FlashCausalLM from text_generation_server.models.custom_modeling.flash_llama_modeling import ( FlashLlamaForCausalLM, +) +from text_generation_server.models.custom_modeling.tensor_parallel import ( TensorParallelEmbedding, TensorParallelRowLinear, TensorParallelColumnLinear, diff --git a/server/text_generation_server/models/flash_neox.py b/server/text_generation_server/models/flash_neox.py index e415a725..655b664e 100644 --- a/server/text_generation_server/models/flash_neox.py +++ b/server/text_generation_server/models/flash_neox.py @@ -8,12 +8,14 @@ from transformers import AutoTokenizer, AutoConfig from typing import Optional, Tuple, List from text_generation_server.models import FlashCausalLM -from text_generation_server.models.custom_modeling.flash_neox_modeling import ( - FlashGPTNeoXForCausalLM, +from text_generation_server.models.custom_modeling.tensor_parallel import ( TensorParallelEmbedding, TensorParallelRowLinear, TensorParallelColumnLinear, ) +from text_generation_server.models.custom_modeling.flash_neox_modeling import ( + FlashGPTNeoXForCausalLM, +) from text_generation_server.utils import ( initialize_torch_distributed, weight_files,