Merge branch 'main' into mi300-compat

This commit is contained in:
fxmarty 2024-04-26 11:28:42 +02:00
commit 7502367043
54 changed files with 29513 additions and 21069 deletions

248
Cargo.lock generated
View File

@ -120,7 +120,7 @@ checksum = "0ae92a5119aa49cdbcf6b9f893fe4e1d98b04ccbf82ee0584ad948a44a734dea"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.58",
"syn 2.0.60",
]
[[package]]
@ -159,7 +159,7 @@ checksum = "16e62a023e7c117e27523144c5d2459f4397fcc3cab0085af8e2224f643a0193"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.58",
"syn 2.0.60",
]
[[package]]
@ -170,7 +170,7 @@ checksum = "c6fa2087f2753a7da8cc1c0dbfcf89579dd57458e36769de5ac750b4671737ca"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.58",
"syn 2.0.60",
]
[[package]]
@ -449,9 +449,9 @@ checksum = "df8670b8c7b9dae1793364eafadf7239c40d669904660c5960d74cfd80b46a53"
[[package]]
name = "cc"
version = "1.0.92"
version = "1.0.94"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2678b2e3449475e95b0aa6f9b506a28e61b3dc8996592b983695e8ebb58a8b41"
checksum = "17f6e324229dc011159fcc089755d1e2e216a90d43a7dea6853ca740b84f35e7"
dependencies = [
"jobserver",
"libc",
@ -510,7 +510,7 @@ dependencies = [
"heck 0.5.0",
"proc-macro2",
"quote",
"syn 2.0.58",
"syn 2.0.60",
]
[[package]]
@ -665,9 +665,9 @@ dependencies = [
[[package]]
name = "darling"
version = "0.14.4"
version = "0.20.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7b750cb3417fd1b327431a470f388520309479ab0bf5e323505daf0290cd3850"
checksum = "54e36fcd13ed84ffdfda6f5be89b31287cbb80c439841fe69e04841435464391"
dependencies = [
"darling_core",
"darling_macro",
@ -675,27 +675,27 @@ dependencies = [
[[package]]
name = "darling_core"
version = "0.14.4"
version = "0.20.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "109c1ca6e6b7f82cc233a97004ea8ed7ca123a9af07a8230878fcfda9b158bf0"
checksum = "9c2cf1c23a687a1feeb728783b993c4e1ad83d99f351801977dd809b48d0a70f"
dependencies = [
"fnv",
"ident_case",
"proc-macro2",
"quote",
"strsim 0.10.0",
"syn 1.0.109",
"syn 2.0.60",
]
[[package]]
name = "darling_macro"
version = "0.14.4"
version = "0.20.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a4aab4dbc9f7611d8b55048a3a16d2d010c2c8334e46304b40ac1cc14bf3b48e"
checksum = "a668eda54683121533a393014d8692171709ff57a7d61f187b6e782719f8933f"
dependencies = [
"darling_core",
"quote",
"syn 1.0.109",
"syn 2.0.60",
]
[[package]]
@ -709,33 +709,33 @@ dependencies = [
[[package]]
name = "derive_builder"
version = "0.12.0"
version = "0.20.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8d67778784b508018359cbc8696edb3db78160bab2c2a28ba7f56ef6932997f8"
checksum = "0350b5cb0331628a5916d6c5c0b72e97393b8b6b03b47a9284f4e7f5a405ffd7"
dependencies = [
"derive_builder_macro",
]
[[package]]
name = "derive_builder_core"
version = "0.12.0"
version = "0.20.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c11bdc11a0c47bc7d37d582b5285da6849c96681023680b906673c5707af7b0f"
checksum = "d48cda787f839151732d396ac69e3473923d54312c070ee21e9effcaa8ca0b1d"
dependencies = [
"darling",
"proc-macro2",
"quote",
"syn 1.0.109",
"syn 2.0.60",
]
[[package]]
name = "derive_builder_macro"
version = "0.12.0"
version = "0.20.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ebcda35c7a396850a55ffeac740804b40ffec779b98fffbb1738f4033f0ee79e"
checksum = "206868b8242f27cecce124c19fd88157fbd0dd334df2587f36417bafbc85097b"
dependencies = [
"derive_builder_core",
"syn 1.0.109",
"syn 2.0.60",
]
[[package]]
@ -800,9 +800,9 @@ dependencies = [
[[package]]
name = "either"
version = "1.10.0"
version = "1.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "11157ac094ffbdde99aa67b23417ebdd801842852b500e395a45a9c0aac03e4a"
checksum = "a47c1c47d2f5964e29c61246e81db715514cd532db6b5116a25ea3c03d6780a2"
[[package]]
name = "encode_unicode"
@ -1018,7 +1018,7 @@ checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.58",
"syn 2.0.60",
]
[[package]]
@ -1423,7 +1423,7 @@ checksum = "c34819042dc3d3971c46c2190835914dfbe0c3c13f61449b2997f4e9722dfa60"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.58",
"syn 2.0.60",
]
[[package]]
@ -1476,9 +1476,9 @@ checksum = "49f1f14873335454500d59611f1cf4a4b0f786f9ac11f4312a78e4cf2566695b"
[[package]]
name = "jobserver"
version = "0.1.29"
version = "0.1.30"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f08474e32172238f2827bd160c67871cdb2801430f65c3979184dc362e3ca118"
checksum = "685a7d121ee3f65ae4fddd72b25a04bb36b6af81bc0828f7d5434c0fe60fa3a2"
dependencies = [
"libc",
]
@ -1703,7 +1703,7 @@ checksum = "38b4faf00617defe497754acde3024865bc143d44a86799b24e191ecff91354f"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.58",
"syn 2.0.60",
]
[[package]]
@ -1775,9 +1775,9 @@ dependencies = [
[[package]]
name = "monostate"
version = "0.1.11"
version = "0.1.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "878c2a1f1c70e5724fa28f101ca787b6a7e8ad5c5e4ae4ca3b0fa4a419fa9075"
checksum = "a20fffcd8ca4c69d31e036a71abc400147b41f90895df4edcb36497a1f8af8bf"
dependencies = [
"monostate-impl",
"serde",
@ -1785,13 +1785,13 @@ dependencies = [
[[package]]
name = "monostate-impl"
version = "0.1.11"
version = "0.1.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f686d68a09079e63b1d2c64aa305095887ce50565f00a922ebfaeeee0d9ba6ce"
checksum = "bf307cbbbd777a9c10cec88ddafee572b3484caad5cce0c9236523c3803105a6"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.58",
"syn 2.0.60",
]
[[package]]
@ -1929,9 +1929,9 @@ dependencies = [
[[package]]
name = "num"
version = "0.4.1"
version = "0.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b05180d69e3da0e530ba2a1dae5110317e49e3b7f3d41be227dc5f92e49ee7af"
checksum = "3135b08af27d103b0a51f2ae0f8632117b7b185ccf931445affa8df530576a41"
dependencies = [
"num-bigint",
"num-complex",
@ -1981,7 +1981,7 @@ checksum = "ed3955f1a9c7c0c15e092f9c887db08b1fc683305fdf6eb6684f22555355e202"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.58",
"syn 2.0.60",
]
[[package]]
@ -2111,7 +2111,7 @@ checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.58",
"syn 2.0.60",
]
[[package]]
@ -2327,7 +2327,7 @@ checksum = "2f38a4412a78282e09a2cf38d195ea5420d15ba0602cb375210efbc877243965"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.58",
"syn 2.0.60",
]
[[package]]
@ -2381,12 +2381,12 @@ checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de"
[[package]]
name = "prettyplease"
version = "0.2.17"
version = "0.2.19"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8d3928fb5db768cb86f891ff014f0144589297e3c6a1aba6ed7cecfdace270c7"
checksum = "5ac2cf0f2e4f42b49f5ffd07dae8d746508ef7526c13940e5f524012ae6c6550"
dependencies = [
"proc-macro2",
"syn 2.0.58",
"syn 2.0.60",
]
[[package]]
@ -2415,9 +2415,9 @@ dependencies = [
[[package]]
name = "proc-macro2"
version = "1.0.79"
version = "1.0.81"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e835ff2298f5721608eb1a980ecaee1aef2c132bf95ecc026a11b7bf3c01c02e"
checksum = "3d1597b0c024618f09a9c3b8655b7e430397a36d23fdafec26d6965e9eec3eba"
dependencies = [
"unicode-ident",
]
@ -2438,7 +2438,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8021cf59c8ec9c432cfc2526ac6b8aa508ecaf29cd415f271b8406c1b851c3fd"
dependencies = [
"quote",
"syn 2.0.58",
"syn 2.0.60",
]
[[package]]
@ -2478,7 +2478,7 @@ dependencies = [
"prost 0.12.4",
"prost-types",
"regex",
"syn 2.0.58",
"syn 2.0.60",
"tempfile",
]
@ -2505,7 +2505,7 @@ dependencies = [
"itertools 0.12.1",
"proc-macro2",
"quote",
"syn 2.0.58",
"syn 2.0.60",
]
[[package]]
@ -2752,12 +2752,6 @@ version = "0.6.29"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f162c6dd7b008981e4d40210aca20b4bd0f9b60ca9271061b07f78537722f2e1"
[[package]]
name = "regex-syntax"
version = "0.7.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dbb5fb1acd8a1a18b3dd5be62d25485eb770e05afb408a9627d14d451bae12da"
[[package]]
name = "regex-syntax"
version = "0.8.3"
@ -2864,7 +2858,7 @@ dependencies = [
"quote",
"rust-embed-utils",
"shellexpand",
"syn 2.0.58",
"syn 2.0.60",
"walkdir",
]
@ -3038,29 +3032,29 @@ dependencies = [
[[package]]
name = "serde"
version = "1.0.197"
version = "1.0.198"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3fb1c873e1b9b056a4dc4c0c198b24c3ffa059243875552b2bd0933b1aee4ce2"
checksum = "9846a40c979031340571da2545a4e5b7c4163bdae79b301d5f86d03979451fcc"
dependencies = [
"serde_derive",
]
[[package]]
name = "serde_derive"
version = "1.0.197"
version = "1.0.198"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7eb0b34b42edc17f6b7cac84a52a1c5f0e1bb2227e997ca9011ea3dd34e8610b"
checksum = "e88edab869b01783ba905e7d0153f9fc1a6505a96e4ad3018011eedb838566d9"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.58",
"syn 2.0.60",
]
[[package]]
name = "serde_json"
version = "1.0.115"
version = "1.0.116"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "12dc5c46daa8e9fdf4f5e71b6cf9a53f2487da0e86e55808e2d35539666497dd"
checksum = "3e17db7126d17feb94eb3fad46bf1a96b034e8aacbc2e775fe81505f8b0b2813"
dependencies = [
"itoa",
"ryu",
@ -3270,7 +3264,7 @@ dependencies = [
"proc-macro2",
"quote",
"rustversion",
"syn 2.0.58",
"syn 2.0.60",
]
[[package]]
@ -3292,9 +3286,9 @@ dependencies = [
[[package]]
name = "syn"
version = "2.0.58"
version = "2.0.60"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "44cfb93f38070beee36b3fef7d4f5a16f27751d94b187b666a5cc5e9b0d30687"
checksum = "909518bc7b1c9b779f1bbf07f2929d35af9f0f37e47c6e9ef7f9dddc1e1821f3"
dependencies = [
"proc-macro2",
"quote",
@ -3399,7 +3393,7 @@ dependencies = [
[[package]]
name = "text-generation-benchmark"
version = "2.0.0"
version = "2.0.1"
dependencies = [
"average",
"clap",
@ -3412,7 +3406,7 @@ dependencies = [
"tabled",
"text-generation-client",
"thiserror",
"tokenizers 0.14.1",
"tokenizers",
"tokio",
"tracing",
"tracing-subscriber",
@ -3420,7 +3414,7 @@ dependencies = [
[[package]]
name = "text-generation-client"
version = "2.0.0"
version = "2.0.1"
dependencies = [
"futures",
"grpc-metadata",
@ -3436,7 +3430,7 @@ dependencies = [
[[package]]
name = "text-generation-launcher"
version = "2.0.0"
version = "2.0.1"
dependencies = [
"clap",
"ctrlc",
@ -3454,7 +3448,7 @@ dependencies = [
[[package]]
name = "text-generation-router"
version = "2.0.0"
version = "2.0.1"
dependencies = [
"async-stream",
"axum",
@ -3482,7 +3476,7 @@ dependencies = [
"serde_json",
"text-generation-client",
"thiserror",
"tokenizers 0.15.2",
"tokenizers",
"tokio",
"tokio-stream",
"tower-http",
@ -3511,7 +3505,7 @@ checksum = "c61f3ba182994efc43764a46c018c347bc492c79f024e705f46567b418f6d4f7"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.58",
"syn 2.0.60",
]
[[package]]
@ -3585,46 +3579,11 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20"
[[package]]
name = "tokenizers"
version = "0.14.1"
version = "0.19.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d9be88c795d8b9f9c4002b3a8f26a6d0876103a6f523b32ea3bac52d8560c17c"
checksum = "e500fad1dd3af3d626327e6a3fe5050e664a6eaa4708b8ca92f1794aaf73e6fd"
dependencies = [
"aho-corasick",
"clap",
"derive_builder",
"esaxx-rs",
"getrandom",
"hf-hub",
"indicatif",
"itertools 0.11.0",
"lazy_static",
"log",
"macro_rules_attribute",
"monostate",
"onig",
"paste",
"rand",
"rayon",
"rayon-cond",
"regex",
"regex-syntax 0.7.5",
"serde",
"serde_json",
"spm_precompiled",
"thiserror",
"unicode-normalization-alignments",
"unicode-segmentation",
"unicode_categories",
]
[[package]]
name = "tokenizers"
version = "0.15.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3dd47962b0ba36e7fd33518fbf1754d136fd1474000162bbf2a8b5fcb2d3654d"
dependencies = [
"aho-corasick",
"clap",
"derive_builder",
"esaxx-rs",
"getrandom",
@ -3688,7 +3647,7 @@ checksum = "5b8a1e28f2deaa14e508979454cb3a223b10b938b45af148bc0986de36f1923b"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.58",
"syn 2.0.60",
]
[[package]]
@ -3837,7 +3796,7 @@ dependencies = [
"proc-macro2",
"prost-build",
"quote",
"syn 2.0.58",
"syn 2.0.60",
]
[[package]]
@ -3910,7 +3869,7 @@ checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.58",
"syn 2.0.60",
]
[[package]]
@ -4151,7 +4110,7 @@ dependencies = [
"proc-macro2",
"quote",
"regex",
"syn 2.0.58",
"syn 2.0.60",
]
[[package]]
@ -4273,7 +4232,7 @@ dependencies = [
"once_cell",
"proc-macro2",
"quote",
"syn 2.0.58",
"syn 2.0.60",
"wasm-bindgen-shared",
]
@ -4307,7 +4266,7 @@ checksum = "e94f17b526d0a461a191c78ea52bbce64071ed5c04c9ffe424dcb38f74171bb7"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.58",
"syn 2.0.60",
"wasm-bindgen-backend",
"wasm-bindgen-shared",
]
@ -4391,7 +4350,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e48a53791691ab099e5e2ad123536d0fff50652600abaf43bbf952894110d0be"
dependencies = [
"windows-core",
"windows-targets 0.52.4",
"windows-targets 0.52.5",
]
[[package]]
@ -4400,7 +4359,7 @@ version = "0.52.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "33ab640c8d7e35bf8ba19b884ba838ceb4fba93a4e8c65a9059d08afcfc683d9"
dependencies = [
"windows-targets 0.52.4",
"windows-targets 0.52.5",
]
[[package]]
@ -4427,7 +4386,7 @@ version = "0.52.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d"
dependencies = [
"windows-targets 0.52.4",
"windows-targets 0.52.5",
]
[[package]]
@ -4462,17 +4421,18 @@ dependencies = [
[[package]]
name = "windows-targets"
version = "0.52.4"
version = "0.52.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7dd37b7e5ab9018759f893a1952c9420d060016fc19a472b4bb20d1bdd694d1b"
checksum = "6f0713a46559409d202e70e28227288446bf7841d3211583a4b53e3f6d96e7eb"
dependencies = [
"windows_aarch64_gnullvm 0.52.4",
"windows_aarch64_msvc 0.52.4",
"windows_i686_gnu 0.52.4",
"windows_i686_msvc 0.52.4",
"windows_x86_64_gnu 0.52.4",
"windows_x86_64_gnullvm 0.52.4",
"windows_x86_64_msvc 0.52.4",
"windows_aarch64_gnullvm 0.52.5",
"windows_aarch64_msvc 0.52.5",
"windows_i686_gnu 0.52.5",
"windows_i686_gnullvm",
"windows_i686_msvc 0.52.5",
"windows_x86_64_gnu 0.52.5",
"windows_x86_64_gnullvm 0.52.5",
"windows_x86_64_msvc 0.52.5",
]
[[package]]
@ -4489,9 +4449,9 @@ checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8"
[[package]]
name = "windows_aarch64_gnullvm"
version = "0.52.4"
version = "0.52.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bcf46cf4c365c6f2d1cc93ce535f2c8b244591df96ceee75d8e83deb70a9cac9"
checksum = "7088eed71e8b8dda258ecc8bac5fb1153c5cffaf2578fc8ff5d61e23578d3263"
[[package]]
name = "windows_aarch64_msvc"
@ -4507,9 +4467,9 @@ checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc"
[[package]]
name = "windows_aarch64_msvc"
version = "0.52.4"
version = "0.52.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "da9f259dd3bcf6990b55bffd094c4f7235817ba4ceebde8e6d11cd0c5633b675"
checksum = "9985fd1504e250c615ca5f281c3f7a6da76213ebd5ccc9561496568a2752afb6"
[[package]]
name = "windows_i686_gnu"
@ -4525,9 +4485,15 @@ checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e"
[[package]]
name = "windows_i686_gnu"
version = "0.52.4"
version = "0.52.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b474d8268f99e0995f25b9f095bc7434632601028cf86590aea5c8a5cb7801d3"
checksum = "88ba073cf16d5372720ec942a8ccbf61626074c6d4dd2e745299726ce8b89670"
[[package]]
name = "windows_i686_gnullvm"
version = "0.52.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "87f4261229030a858f36b459e748ae97545d6f1ec60e5e0d6a3d32e0dc232ee9"
[[package]]
name = "windows_i686_msvc"
@ -4543,9 +4509,9 @@ checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406"
[[package]]
name = "windows_i686_msvc"
version = "0.52.4"
version = "0.52.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1515e9a29e5bed743cb4415a9ecf5dfca648ce85ee42e15873c3cd8610ff8e02"
checksum = "db3c2bf3d13d5b658be73463284eaf12830ac9a26a90c717b7f771dfe97487bf"
[[package]]
name = "windows_x86_64_gnu"
@ -4561,9 +4527,9 @@ checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e"
[[package]]
name = "windows_x86_64_gnu"
version = "0.52.4"
version = "0.52.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5eee091590e89cc02ad514ffe3ead9eb6b660aedca2183455434b93546371a03"
checksum = "4e4246f76bdeff09eb48875a0fd3e2af6aada79d409d33011886d3e1581517d9"
[[package]]
name = "windows_x86_64_gnullvm"
@ -4579,9 +4545,9 @@ checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc"
[[package]]
name = "windows_x86_64_gnullvm"
version = "0.52.4"
version = "0.52.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "77ca79f2451b49fa9e2af39f0747fe999fcda4f5e241b2898624dca97a1f2177"
checksum = "852298e482cd67c356ddd9570386e2862b5673c85bd5f88df9ab6802b334c596"
[[package]]
name = "windows_x86_64_msvc"
@ -4597,9 +4563,9 @@ checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538"
[[package]]
name = "windows_x86_64_msvc"
version = "0.52.4"
version = "0.52.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "32b752e52a2da0ddfbdbcc6fceadfeede4c939ed16d13e648833a61dfb611ed8"
checksum = "bec47e5bfd1bff0eeaf6d8b485cc1074891a197ab4225d504cb7a1ab88b02bf0"
[[package]]
name = "winnow"
@ -4637,7 +4603,7 @@ checksum = "9ce1b18ccd8e73a9321186f97e46f9f04b778851177567b1975109d26a08d2a6"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.58",
"syn 2.0.60",
]
[[package]]

View File

@ -9,11 +9,15 @@ members = [
resolver = "2"
[workspace.package]
version = "2.0.0"
version = "2.0.1"
edition = "2021"
authors = ["Olivier Dehaene"]
homepage = "https://github.com/huggingface/text-generation-inference"
[workspace.dependencies]
tokenizers = { version = "0.19.1", features = ["http"] }
hf-hub = { version = "0.3.1", features = ["tokio"] }
[profile.release]
debug = 1
incremental = true

View File

@ -23,9 +23,9 @@ serde_json = "1.0"
tabled = "0.14.0"
text-generation-client = { path = "../router/client" }
thiserror = "1.0.48"
tokenizers = { version = "0.14.0", features = ["http"] }
tokenizers = { workspace = true }
tokio = { version = "1.32.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync", "macros"] }
tui = {package = "ratatui", version = "0.23", default-features = false, features = ["crossterm"]}
tracing = "0.1.37"
tracing-subscriber = { version = "0.3.17", features = ["json", "env-filter"] }
hf-hub = "0.3.1"
hf-hub = { workspace = true }

View File

@ -10,7 +10,7 @@
"name": "Apache 2.0",
"url": "https://www.apache.org/licenses/LICENSE-2.0"
},
"version": "2.0.0"
"version": "2.0.1"
},
"paths": {
"/": {
@ -408,9 +408,14 @@
},
"responses": {
"200": {
"description": "Generated Text",
"description": "Generated Chat Completion",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/ChatCompletion"
}
},
"text/event-stream": {
"schema": {
"$ref": "#/components/schemas/ChatCompletionChunk"
}
@ -492,11 +497,16 @@
},
"responses": {
"200": {
"description": "Generated Text",
"description": "Generated Chat Completion",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/ChatCompletionChunk"
"$ref": "#/components/schemas/Completion"
}
},
"text/event-stream": {
"schema": {
"$ref": "#/components/schemas/CompletionCompleteChunk"
}
}
}
@ -930,7 +940,7 @@
"tool_prompt": {
"type": "string",
"description": "A prompt to be appended before the tools",
"example": "\"Based on the conversation, please choose the most appropriate tool to use: \"",
"example": "\"You will be presented with a JSON schema representing a set of tools.\nIf the user request lacks of sufficient information to make a precise tool selection: Do not invent any tool's properties, instead notify with an error message.\n\nJSON Schema:\n\"",
"nullable": true
},
"tools": {
@ -1071,7 +1081,10 @@
"example": "mistralai/Mistral-7B-Instruct-v0.2"
},
"prompt": {
"type": "string",
"type": "array",
"items": {
"type": "string"
},
"description": "The prompt to generate completions for.",
"example": "What is Deep Learning?"
},
@ -1234,17 +1247,17 @@
"type": "object",
"required": [
"name",
"parameters"
"arguments"
],
"properties": {
"arguments": {},
"description": {
"type": "string",
"nullable": true
},
"name": {
"type": "string"
},
"parameters": {}
}
}
},
"GenerateParameters": {
@ -1260,7 +1273,7 @@
},
"decoder_input_details": {
"type": "boolean",
"default": "true"
"default": "false"
},
"details": {
"type": "boolean",
@ -1285,6 +1298,7 @@
"$ref": "#/components/schemas/GrammarType"
}
],
"default": "null",
"nullable": true
},
"max_new_tokens": {
@ -1478,6 +1492,7 @@
"max_batch_total_tokens",
"max_waiting_tokens",
"validation_workers",
"max_client_batch_size",
"version"
],
"properties": {
@ -1503,6 +1518,11 @@
"example": "2",
"minimum": 0
},
"max_client_batch_size": {
"type": "integer",
"example": "32",
"minimum": 0
},
"max_concurrent_requests": {
"type": "integer",
"description": "Router Parameters",

View File

@ -1,8 +1,10 @@
# Guidance
Text Generation Inference (TGI) now supports [JSON and regex grammars](#grammar-and-constraints) and [tools and functions](#tools-and-functions) to help developer guide LLM responses to fit their needs.
Text Generation Inference (TGI) now supports [JSON and regex grammars](#grammar-and-constraints) and [tools and functions](#tools-and-functions) to help developers guide LLM responses to fit their needs.
These feature are available starting from version `1.4.3`. They are accessible via the [text_generation](https://pypi.org/project/text-generation/) library and is compatible with OpenAI's client libraries. The following guide will walk you through the new features and how to use them!
These feature are available starting from version `1.4.3`. They are accessible via the [text_generation](https://pypi.org/project/text-generation/) library. The tool support is compatible with OpenAI's client libraries. The following guide will walk you through the new features and how to use them!
> The Grammar guidance support is currently only available in the TGI API due to lack of support in Open AI API.
## Quick Start
@ -16,7 +18,7 @@ If you're not up to date, grab the latest version and let's get started!
- [The Grammar Parameter](#the-grammar-parameter): Shape your AI's responses with precision.
- [Constrain with Pydantic](#constrain-with-pydantic): Define a grammar using Pydantic models.
- [JSON Schema Integration](#json-schema-integration): Fine grain control over your requests via JSON schema.
- [JSON Schema Integration](#json-schema-integration): Fine-grained control over your requests via JSON schema.
- [Using the client](#using-the-client): Use TGI's client libraries to shape the AI's responses.
### Tools and Functions
@ -72,9 +74,9 @@ curl localhost:3000/generate \
```
A grammar can be defined using Pydantic models, JSON schema, or regular expressions. The AI will then generate a response that conforms to the specified grammar.
A grammar can be defined using Pydantic models, JSON schemas, or regular expressions. The AI will then generate a response that conforms to the specified grammar.
> Note: A grammar must compile to a intermediate representation to constrain the output. Grammar compilation is a computationally expensive and may take a few seconds to complete on the first request. Subsequent requests will use the cached grammar and will be much faster.
> Note: A grammar must compile to an intermediate representation to constrain the output. Grammar compilation is a computationally expensive and may take a few seconds to complete on the first request. Subsequent requests will use the cached grammar and will be much faster.
### Constrain with Pydantic
@ -151,7 +153,7 @@ json_schema = {
}
data = {
"inputs": "[INST]convert to JSON: I saw a puppy a cat and a raccoon during my bike ride in the park [/INST]",
"inputs": "convert to JSON: I saw a puppy a cat and a raccoon during my bike ride in the park",
"parameters": {
"max_new_tokens": 200,
"repetition_penalty": 1.3,

View File

@ -1,5 +1,6 @@
## Speculation
Speculative decoding, assisted generation, Medusa, and others are a few different names for the same idea.
The idea is to generate tokens *before* the large model actually runs, and only *check* if those tokens where valid.
@ -36,7 +37,7 @@ In order to use medusa models in TGI, simply point to a medusa enabled model, an
If you don't have a medusa model, or don't have the resource to fine-tune, you can try to use `n-gram`.
Ngram works by trying to find in the previous sequence existing tokens that match, and use those as speculation.
N-gram works by trying to find matching tokens in the previous sequence, and use those as speculation for generating new tokens. For example, if the tokens "np.mean" appear multiple times in the sequence, the model can speculate that the next continuation of the tokens "np." is probably also "mean".
This is an extremely simple method, which works best for code, or highly repetitive text. This might not be beneficial, if the speculation misses too much.

View File

@ -15,7 +15,7 @@ Token streaming is the mode in which the server returns the tokens one by one as
/>
</div>
With token streaming, the server can start returning the tokens one by one before having to generate the whole response. Users can have a sense of the generation's quality earlier than the end of the generation. This has different positive effects:
With token streaming, the server can start returning the tokens one by one before having to generate the whole response. Users can have a sense of the generation's quality before the end of the generation. This has different positive effects:
* Users can get results orders of magnitude earlier for extremely long queries.
* Seeing something in progress allows users to stop the generation if it's not going in the direction they expect.
@ -116,7 +116,7 @@ curl -N 127.0.0.1:8080/generate_stream \
First, we need to install the `@huggingface/inference` library.
`npm install @huggingface/inference`
If you're using the free Inference API, you can use `HfInference`. If you're using inference endpoints, you can use `HfInferenceEndpoint`. Let's
If you're using the free Inference API, you can use `HfInference`. If you're using inference endpoints, you can use `HfInferenceEndpoint`.
We can create a `HfInferenceEndpoint` providing our endpoint URL and credential.

View File

@ -18,8 +18,8 @@ Text Generation Inference implements many optimizations and features, such as:
- Logits warper (temperature scaling, top-p, top-k, repetition penalty)
- Stop sequences
- Log probabilities
- Custom Prompt Generation: Easily generate text by providing custom prompts to guide the model's output.
- Fine-tuning Support: Utilize fine-tuned models for specific tasks to achieve higher accuracy and performance.
- [Guidance](../conceptual/guidance): Enable function calling and tool-use by forcing the model to generate structured outputs based on your own predefined output schemas.
Text Generation Inference is used in production by multiple projects, such as:

View File

@ -293,6 +293,7 @@ def launcher(event_loop):
dtype: Optional[str] = None,
revision: Optional[str] = None,
max_input_length: Optional[int] = None,
max_batch_prefill_tokens: Optional[int] = None,
max_total_tokens: Optional[int] = None,
):
port = random.randint(8000, 10_000)
@ -334,6 +335,9 @@ def launcher(event_loop):
if max_input_length:
args.append("--max-input-length")
args.append(str(max_input_length))
if max_batch_prefill_tokens:
args.append("--max-batch-prefill-tokens")
args.append(str(max_batch_prefill_tokens))
if max_total_tokens:
args.append("--max-total-tokens")
args.append(str(max_total_tokens))
@ -371,6 +375,7 @@ def launcher(event_loop):
dtype: Optional[str] = None,
revision: Optional[str] = None,
max_input_length: Optional[int] = None,
max_batch_prefill_tokens: Optional[int] = None,
max_total_tokens: Optional[int] = None,
):
port = random.randint(8000, 10_000)
@ -395,6 +400,9 @@ def launcher(event_loop):
if max_input_length:
args.append("--max-input-length")
args.append(str(max_input_length))
if max_batch_prefill_tokens:
args.append("--max-batch-prefill-tokens")
args.append(str(max_batch_prefill_tokens))
if max_total_tokens:
args.append("--max-total-tokens")
args.append(str(max_total_tokens))

View File

@ -17,7 +17,7 @@
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native",
"system_fingerprint": "2.0.1-native",
"usage": {
"completion_tokens": 100,
"prompt_tokens": 60,

View File

@ -29,7 +29,7 @@
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native",
"system_fingerprint": "2.0.1-native",
"usage": {
"completion_tokens": 36,
"prompt_tokens": 8,

View File

@ -12,7 +12,7 @@
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native"
"system_fingerprint": "2.0.1-native"
},
{
"choices": [
@ -27,7 +27,7 @@
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native"
"system_fingerprint": "2.0.1-native"
},
{
"choices": [
@ -42,7 +42,7 @@
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native"
"system_fingerprint": "2.0.1-native"
},
{
"choices": [
@ -57,7 +57,7 @@
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native"
"system_fingerprint": "2.0.1-native"
},
{
"choices": [
@ -72,7 +72,7 @@
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native"
"system_fingerprint": "2.0.1-native"
},
{
"choices": [
@ -87,7 +87,7 @@
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native"
"system_fingerprint": "2.0.1-native"
},
{
"choices": [
@ -102,7 +102,7 @@
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native"
"system_fingerprint": "2.0.1-native"
},
{
"choices": [
@ -117,7 +117,7 @@
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native"
"system_fingerprint": "2.0.1-native"
},
{
"choices": [
@ -132,7 +132,7 @@
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native"
"system_fingerprint": "2.0.1-native"
},
{
"choices": [
@ -147,7 +147,7 @@
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native"
"system_fingerprint": "2.0.1-native"
},
{
"choices": [
@ -162,7 +162,7 @@
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native"
"system_fingerprint": "2.0.1-native"
},
{
"choices": [
@ -177,7 +177,7 @@
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native"
"system_fingerprint": "2.0.1-native"
},
{
"choices": [
@ -192,7 +192,7 @@
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native"
"system_fingerprint": "2.0.1-native"
},
{
"choices": [
@ -207,7 +207,7 @@
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native"
"system_fingerprint": "2.0.1-native"
},
{
"choices": [
@ -222,7 +222,7 @@
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native"
"system_fingerprint": "2.0.1-native"
},
{
"choices": [
@ -237,7 +237,7 @@
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native"
"system_fingerprint": "2.0.1-native"
},
{
"choices": [
@ -252,7 +252,7 @@
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native"
"system_fingerprint": "2.0.1-native"
},
{
"choices": [
@ -267,7 +267,7 @@
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native"
"system_fingerprint": "2.0.1-native"
},
{
"choices": [
@ -282,7 +282,7 @@
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native"
"system_fingerprint": "2.0.1-native"
},
{
"choices": [
@ -297,7 +297,7 @@
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native"
"system_fingerprint": "2.0.1-native"
},
{
"choices": [
@ -312,7 +312,7 @@
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native"
"system_fingerprint": "2.0.1-native"
},
{
"choices": [
@ -327,7 +327,7 @@
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native"
"system_fingerprint": "2.0.1-native"
},
{
"choices": [
@ -342,7 +342,7 @@
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native"
"system_fingerprint": "2.0.1-native"
},
{
"choices": [
@ -357,7 +357,7 @@
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native"
"system_fingerprint": "2.0.1-native"
},
{
"choices": [
@ -372,7 +372,7 @@
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native"
"system_fingerprint": "2.0.1-native"
},
{
"choices": [
@ -387,7 +387,7 @@
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native"
"system_fingerprint": "2.0.1-native"
},
{
"choices": [
@ -402,7 +402,7 @@
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native"
"system_fingerprint": "2.0.1-native"
},
{
"choices": [
@ -417,7 +417,7 @@
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native"
"system_fingerprint": "2.0.1-native"
},
{
"choices": [
@ -432,7 +432,7 @@
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native"
"system_fingerprint": "2.0.1-native"
},
{
"choices": [
@ -447,7 +447,7 @@
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native"
"system_fingerprint": "2.0.1-native"
},
{
"choices": [
@ -462,7 +462,7 @@
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native"
"system_fingerprint": "2.0.1-native"
},
{
"choices": [
@ -477,7 +477,7 @@
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native"
"system_fingerprint": "2.0.1-native"
},
{
"choices": [
@ -492,7 +492,7 @@
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native"
"system_fingerprint": "2.0.1-native"
},
{
"choices": [
@ -507,7 +507,7 @@
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native"
"system_fingerprint": "2.0.1-native"
},
{
"choices": [
@ -522,7 +522,7 @@
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native"
"system_fingerprint": "2.0.1-native"
},
{
"choices": [
@ -537,7 +537,7 @@
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native"
"system_fingerprint": "2.0.1-native"
},
{
"choices": [
@ -552,7 +552,7 @@
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native"
"system_fingerprint": "2.0.1-native"
},
{
"choices": [
@ -567,7 +567,7 @@
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native"
"system_fingerprint": "2.0.1-native"
},
{
"choices": [
@ -582,7 +582,7 @@
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native"
"system_fingerprint": "2.0.1-native"
},
{
"choices": [
@ -597,6 +597,6 @@
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native"
"system_fingerprint": "2.0.1-native"
}
]

View File

@ -11,7 +11,7 @@
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native",
"system_fingerprint": "2.0.1-native",
"usage": {
"completion_tokens": 5,
"prompt_tokens": 6,

View File

@ -0,0 +1,89 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 3735,
"logprob": -8.5625,
"text": "Test"
},
{
"id": 2159,
"logprob": -10.78125,
"text": "request"
}
],
"seed": 0,
"tokens": [
{
"id": 288,
"logprob": -0.2854004,
"special": false,
"text": "ing"
},
{
"id": 264,
"logprob": -0.37573242,
"special": false,
"text": " a"
},
{
"id": 633,
"logprob": -0.09301758,
"special": false,
"text": " new"
},
{
"id": 4480,
"logprob": -0.3322754,
"special": false,
"text": " feature"
},
{
"id": 297,
"logprob": -0.8510742,
"special": false,
"text": " in"
},
{
"id": 272,
"logprob": -0.13464355,
"special": false,
"text": " the"
},
{
"id": 2039,
"logprob": 0.0,
"special": false,
"text": " game"
},
{
"id": 28723,
"logprob": -0.89990234,
"special": false,
"text": "."
},
{
"id": 13,
"logprob": 0.0,
"special": false,
"text": "\n"
},
{
"id": 13,
"logprob": 0.0,
"special": false,
"text": "\n"
}
],
"top_tokens": null
},
"generated_text": "Test requesting a new feature in the game.\n\n"
}

View File

@ -0,0 +1,73 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [],
"seed": null,
"tokens": [
{
"id": 330,
"logprob": -0.13000488,
"special": false,
"text": " A"
},
{
"id": 13088,
"logprob": -0.6713867,
"special": false,
"text": " chicken"
},
{
"id": 349,
"logprob": -0.2980957,
"special": false,
"text": " is"
},
{
"id": 6398,
"logprob": -0.060638428,
"special": false,
"text": " sitting"
},
{
"id": 356,
"logprob": -0.27319336,
"special": false,
"text": " on"
},
{
"id": 264,
"logprob": -0.140625,
"special": false,
"text": " a"
},
{
"id": 17972,
"logprob": -0.040405273,
"special": false,
"text": " pile"
},
{
"id": 302,
"logprob": -0.0002708435,
"special": false,
"text": " of"
},
{
"id": 2445,
"logprob": -0.095336914,
"special": false,
"text": " money"
},
{
"id": 28723,
"logprob": -0.0068359375,
"special": false,
"text": "."
}
],
"top_tokens": null
},
"generated_text": " A chicken is sitting on a pile of money."
}

View File

@ -30,7 +30,7 @@
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native",
"system_fingerprint": "2.0.1-native",
"usage": {
"completion_tokens": 37,
"prompt_tokens": 524,

View File

@ -30,7 +30,7 @@
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native",
"system_fingerprint": "2.0.1-native",
"usage": {
"completion_tokens": 37,
"prompt_tokens": 524,

View File

@ -30,7 +30,7 @@
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native",
"system_fingerprint": "2.0.1-native",
"usage": {
"completion_tokens": 48,
"prompt_tokens": 320,

View File

@ -23,5 +23,5 @@
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native"
"system_fingerprint": "2.0.1-native"
}

View File

@ -0,0 +1,81 @@
import pytest
import base64
# TODO fix the server parsser to count inline image tokens correctly
def get_chicken():
with open("integration-tests/images/chicken_on_money.png", "rb") as image_file:
encoded_string = base64.b64encode(image_file.read())
return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
@pytest.fixture(scope="module")
def flash_idefics2_next_handle(launcher):
with launcher(
"HuggingFaceM4/idefics2-8b",
) as handle:
yield handle
@pytest.fixture(scope="module")
async def flash_idefics2_next(flash_idefics2_next_handle):
await flash_idefics2_next_handle.health(300)
return flash_idefics2_next_handle.client
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_idefics2_next_simple(flash_idefics2_next, response_snapshot):
chicken = get_chicken()
response = await flash_idefics2_next.generate(
f"User:![]({chicken})Write me a short story<end_of_utterance> \nAssistant:",
max_new_tokens=10,
)
assert (
response.generated_text == " A chicken is sitting on a pile of money."
), f"{repr(response.generated_text)}"
assert response.details.generated_tokens == 10
assert response == response_snapshot
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_idefics2_next_all_params(flash_idefics2_next, response_snapshot):
response = await flash_idefics2_next.generate(
"Test request",
max_new_tokens=10,
repetition_penalty=1.2,
return_full_text=True,
stop_sequences=["test"],
temperature=0.5,
top_p=0.9,
top_k=10,
truncate=5,
typical_p=0.9,
watermark=True,
decoder_input_details=True,
seed=0,
)
assert response.details.generated_tokens == 10
assert response == response_snapshot
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_idefics2_next_load(
flash_idefics2_next, generate_load, response_snapshot
):
chicken = get_chicken()
responses = await generate_load(
flash_idefics2_next,
f"User:![]({chicken})Write me a short story<end_of_utterance> \nAssistant:",
max_new_tokens=10,
n=4,
)
generated_texts = [r.generated_text for r in responses]
assert generated_texts[0] == " A chicken is sitting on a pile of money."
assert len(generated_texts) == 4
assert all([r.generated_text == generated_texts[0] for r in responses])
assert responses == response_snapshot

View File

@ -1,6 +1,6 @@
[tool.poetry]
name = "text-generation-integration-tests"
version = "2.0.0"
version = "2.0.1"
description = "Text Generation Inference integration tests"
authors = ["Nicolas Patry <nicolas@huggingface.co>"]

View File

@ -1,71 +1,94 @@
import { check, randomSeed } from 'k6';
import { check } from 'k6';
import { scenario } from 'k6/execution';
import http from 'k6/http';
import { Trend, Counter } from 'k6/metrics';
import { randomItem } from 'https://jslib.k6.io/k6-utils/1.2.0/index.js';
const seed = 0;
const host = __ENV.HOST || '127.0.0.1:8000';
const host = __ENV.HOST;
const model_id = __ENV.MODEL_ID;
const timePerToken = new Trend('time_per_token', true);
const tokens = new Counter('tokens');
const new_tokens = new Counter('new_tokens');
const input_tokens = new Counter('input_tokens');
const max_new_tokens = 50;
randomSeed(seed);
// const shareGPT = JSON.parse(open("ShareGPT_V3_unfiltered_cleaned_split.json"))
const shareGPT = JSON.parse(open("small.json"))
export function get_options(reference_latency_ms){
export function get_options() {
return {
thresholds: {
http_req_failed: ['rate==0'],
time_per_token: [{
threshold: `p(50)<${5 * reference_latency_ms}`,
abortOnFail: true,
delayAbortEval: '10s'
}],
// time_per_token: [{
// threshold: `p(50)<${5 * reference_latency_ms}`,
// abortOnFail: true,
// delayAbortEval: '10s'
// }],
},
scenarios: {
load_test: {
single_user: {
executor: 'constant-arrival-rate',
duration: '60s',
preAllocatedVUs: 10,
rate: 10,
preAllocatedVUs: 1,
rate: 1,
timeUnit: '1s',
},
// load_test: {
// executor: 'constant-arrival-rate',
// duration: '60s',
// preAllocatedVUs: 100,
// rate: 1,
// timeUnit: '1s',
// },
// breakpoint: {
// executor: 'ramping-arrival-rate', //Assure load increase if the system slows
// preAllocatedVUs: 1000,
// stages: [
// { duration: '60s', target: 100 }, // just slowly ramp-up to a HUGE load
// ],
// },
// throughput: {
// executor: 'shared-iterations',
// vus: 100,
// iterations: 200,
// maxDuration: '40s',
// },
},
};
}
function generate_payload(gpt, max_new_tokens) {
const input = gpt["conversations"][0]["value"];
return { "messages": [{ "role": "user", "content": input }], "temperature": 0, "model": `${model_id}`, "max_tokens": max_new_tokens }
}
export function run(host, generate_payload, max_new_tokens) {
const headers = {'Content-Type': 'application/json'};
const query = randomItem(shareGPT);
const payload = JSON.stringify(generate_payload(query));
const res = http.post(`http://${host}/generate`, payload, {
export const options = get_options();
export default function run() {
const headers = { 'Content-Type': 'application/json' };
const query = shareGPT[scenario.iterationInTest % shareGPT.length];
const payload = JSON.stringify(generate_payload(query, max_new_tokens));
const res = http.post(`http://${host}/v1/chat/completions`, payload, {
headers,
});
if(res.status >= 400 && res.status < 500){
if (res.status >= 400 && res.status < 500) {
return;
}
check(res, {
'Post status is 200': (r) => res.status === 200,
'Post status is 200': (res) => res.status === 200,
});
const duration = res.timings.duration;
if (res.status === 200) {
const body = res.json();
const n_tokens = body.details.tokens.length;
const latency_ms_per_token = duration / n_tokens;
const completion_tokens = body.usage.completion_tokens;
const latency_ms_per_token = duration / completion_tokens;
timePerToken.add(latency_ms_per_token);
const latency_in_s = latency_ms_per_token / 1000;
const individual_throughput = 1 / latency_in_s;
const _input_tokens = body.details.prefill.length;
tokens.add(n_tokens + _input_tokens);
input_tokens.add(_input_tokens);
new_tokens.add(n_tokens);
const prompt_tokens = body.usage.prompt_tokens;
input_tokens.add(prompt_tokens);
new_tokens.add(completion_tokens);
tokens.add(completion_tokens + prompt_tokens);
}
}

View File

@ -1,17 +0,0 @@
import { get_options, run } from "./common.js";
const reference_latency_ms = 70;
const host = __ENV.HOST || '127.0.0.1:8000';
const max_new_tokens = 50;
function generate_payload(gpt){
const input = gpt["conversations"][0]["value"];
return {"inputs": input, "parameters": {"max_new_tokens": max_new_tokens, "decoder_input_details": true}}
}
export const options = get_options(reference_latency_ms);
export default function(){
run(host, generate_payload, max_new_tokens);
}

View File

@ -1,17 +0,0 @@
import { get_options, run } from "./common.js";
const reference_latency_ms = 22;
const host = __ENV.HOST || '127.0.0.1:8000';
const max_new_tokens = 50;
function generate_payload(gpt){
const input = gpt["conversations"][0]["value"];
return {"prompt": input, "temperature": 0.5, "ignore_eos": true}
}
export const options = get_options(reference_latency_ms);
export default function(){
run(host, generate_payload, max_new_tokens);
}

View File

@ -21,7 +21,7 @@ axum-tracing-opentelemetry = "0.14.1"
text-generation-client = { path = "client" }
clap = { version = "4.4.5", features = ["derive", "env"] }
futures = "0.3.28"
hf-hub = { version = "0.3.0", features = ["tokio"] }
hf-hub = { workspace = true }
jsonschema = { version = "0.17.1", features = ["draft202012"] }
metrics = "0.21.1"
metrics-exporter-prometheus = { version = "0.12.1", features = [] }
@ -33,7 +33,7 @@ reqwest = { version = "0.11.20", features = [] }
serde = "1.0.188"
serde_json = "1.0.107"
thiserror = "1.0.48"
tokenizers = { version = "0.15.1", features = ["http"] }
tokenizers = { workspace = true}
tokio = { version = "1.32.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
tokio-stream = "0.1.14"
tower-http = { version = "0.4.4", features = ["cors"] }

View File

@ -114,8 +114,12 @@ impl Client {
let truncate = min(max_input_length, max_prefill_tokens - n_tokens);
let mut inputs = String::new();
inputs.push_str("![](data:image/jpeg;base64,iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAABg2lDQ1BJQ0MgcHJvZmlsZQAAKJF9kT1Iw0AcxV/TSotUROxQxCFDdbKLijjWKhShQqgVWnUwufQLmrQkKS6OgmvBwY/FqoOLs64OroIg+AHi7OCk6CIl/i8ptIjx4Lgf7+497t4BQqvKNDOQADTdMjKppJjLr4rBVwQQwhAERGVm1uckKQ3P8XUPH1/v4jzL+9yfY0AtmAzwicQJVjcs4g3imU2rznmfOMLKskp8Tjxh0AWJH7muuPzGueSwwDMjRjYzTxwhFks9rPQwKxsa8TRxTNV0yhdyLquctzhr1Qbr3JO/MFzQV5a5TnMUKSxiCRJEKGiggiosxGnVSTGRof2kh3/E8UvkUshVASPHAmrQIDt+8D/43a1ZnJp0k8JJoO/Ftj/GgOAu0G7a9vexbbdPAP8zcKV3/bUWMPtJerOrxY6AwW3g4rqrKXvA5Q4QfarLhuxIfppCsQi8n9E35YHhW6B/ze2ts4/TByBLXaVvgINDYLxE2ese7w719vbvmU5/PycecohsjayNAAAACXBIWXMAAC4jAAAuIwF4pT92AAAAB3RJTUUH6AQIEQMnlTSSjwAAABl0RVh0Q29tbWVudABDcmVhdGVkIHdpdGggR0lNUFeBDhcAAAASSURBVDjLY2AYBaNgFIyCoQsABMQAAeRw1DoAAAAASUVORK5CYII=");
inputs.push_str(&"_test ".to_string().repeat(max_input_length as usize));
if n_tokens == 0 {
// 1 request is enough to test vision heads.
// Sending images on other queries messes up easily with truncation.
inputs.push_str("![](data:image/jpeg;base64,iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAABg2lDQ1BJQ0MgcHJvZmlsZQAAKJF9kT1Iw0AcxV/TSotUROxQxCFDdbKLijjWKhShQqgVWnUwufQLmrQkKS6OgmvBwY/FqoOLs64OroIg+AHi7OCk6CIl/i8ptIjx4Lgf7+497t4BQqvKNDOQADTdMjKppJjLr4rBVwQQwhAERGVm1uckKQ3P8XUPH1/v4jzL+9yfY0AtmAzwicQJVjcs4g3imU2rznmfOMLKskp8Tjxh0AWJH7muuPzGueSwwDMjRjYzTxwhFks9rPQwKxsa8TRxTNV0yhdyLquctzhr1Qbr3JO/MFzQV5a5TnMUKSxiCRJEKGiggiosxGnVSTGRof2kh3/E8UvkUshVASPHAmrQIDt+8D/43a1ZnJp0k8JJoO/Ftj/GgOAu0G7a9vexbbdPAP8zcKV3/bUWMPtJerOrxY6AwW3g4rqrKXvA5Q4QfarLhuxIfppCsQi8n9E35YHhW6B/ze2ts4/TByBLXaVvgINDYLxE2ese7w719vbvmU5/PycecohsjayNAAAACXBIWXMAAC4jAAAuIwF4pT92AAAAB3RJTUUH6AQIEQMnlTSSjwAAABl0RVh0Q29tbWVudABDcmVhdGVkIHdpdGggR0lNUFeBDhcAAAASSURBVDjLY2AYBaNgFIyCoQsABMQAAeRw1DoAAAAASUVORK5CYII=)");
}
requests.push(Request {
id: 0,

View File

@ -57,6 +57,31 @@ fn select_best_resolution(
best_fit.unwrap_or((original_height, original_width))
}
fn get_unpadded_features(
height: usize,
width: usize,
npatches: usize,
num_patch_height: usize,
num_patch_width: usize,
) -> (usize, usize) {
let current_height = npatches * num_patch_height;
let current_width = npatches * num_patch_width;
let aspect_ratio: f64 = width as f64 / height as f64;
let current_aspect_ratio: f64 = current_width as f64 / current_height as f64;
let (current_height, current_width) = if aspect_ratio > current_aspect_ratio {
let new_height = (height * current_width) / width;
(new_height, current_width)
} else {
let new_width = (width * current_height) / height;
(current_height, new_width)
};
let unpadded_features = current_height * current_width;
let newline_features = current_height;
(unpadded_features, newline_features)
}
impl LlavaNext {
pub fn get_number_of_features(&self, height: usize, width: usize) -> usize {
let image_size = self.vision_config.image_size;
@ -65,11 +90,9 @@ impl LlavaNext {
let npatches = image_size / patch_size;
let (num_patch_height, num_patch_width) =
get_anyres_image_grid_shape(height, width, &self.image_grid_pinpoints, image_size);
// Ceil
let height_of_patch = (height * npatches + width - 1) / width;
let unpadded_features = npatches * height_of_patch * num_patch_height * num_patch_width;
// They are only added after width
let newline_features = height_of_patch * num_patch_width;
let (unpadded_features, newline_features) =
get_unpadded_features(height, width, npatches, num_patch_height, num_patch_width);
// The base patch covers the entire image
let base_features = npatches.pow(2);
unpadded_features + newline_features + base_features
@ -84,6 +107,17 @@ pub struct ClipVisionModel {
patch_size: usize,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(tag = "model_type")]
#[serde(rename_all = "snake_case")]
pub struct Idefics2 {}
impl Idefics2 {
pub fn get_number_of_features(&self, _height: usize, _width: usize) -> usize {
320
}
}
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(tag = "model_type")]
#[serde(rename_all = "snake_case")]
@ -92,6 +126,7 @@ pub enum Config {
ClipVisionModel(ClipVisionModel),
Mistral,
Idefics,
Idefics2(Idefics2),
Ssm,
GptBigcode,
Santacoder,
@ -146,13 +181,17 @@ mod test {
],
};
let slots = config.get_number_of_features(20, 20);
assert_eq!(slots, 1176);
let slots = config.get_number_of_features(640, 640);
assert_eq!(slots, 2928);
let slots = config.get_number_of_features(480, 640);
assert_eq!(slots, 2340);
let slots = config.get_number_of_features(899, 1024);
assert_eq!(slots, 2732);
assert_eq!(slots, 2634);
let slots = config.get_number_of_features(1024, 899);
assert_eq!(slots, 3320);
assert_eq!(slots, 2640);
let slots = config.get_number_of_features(1067, 1600);
assert_eq!(slots, 2144);
}
}

View File

@ -73,9 +73,9 @@ pub struct HubTokenizerConfig {
}
impl HubTokenizerConfig {
pub fn from_file(filename: &std::path::Path) -> Self {
let content = std::fs::read_to_string(filename).unwrap();
serde_json::from_str(&content).unwrap_or_default()
pub fn from_file<P: AsRef<std::path::Path>>(filename: P) -> Option<Self> {
let content = std::fs::read_to_string(filename).ok()?;
serde_json::from_str(&content).ok()
}
}
@ -116,6 +116,7 @@ mod token_serde {
))
}
}
Value::Null => Ok(None),
_ => Err(de::Error::custom("invalid token format")),
}
}
@ -168,9 +169,12 @@ pub struct Info {
#[derive(Clone, Debug, Deserialize, ToSchema, Default)]
pub(crate) struct GenerateParameters {
/// Generate best_of sequences and return the one if the highest token logprobs.
#[serde(default)]
#[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 1)]
pub best_of: Option<usize>,
/// The value used to module the logits distribution.
#[serde(default)]
#[schema(
exclusive_minimum = 0.0,
@ -179,6 +183,9 @@ pub(crate) struct GenerateParameters {
example = 0.5
)]
pub temperature: Option<f32>,
/// The parameter for repetition penalty. 1.0 means no penalty.
/// See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
#[serde(default)]
#[schema(
exclusive_minimum = 0.0,
@ -187,6 +194,10 @@ pub(crate) struct GenerateParameters {
example = 1.03
)]
pub repetition_penalty: Option<f32>,
/// The parameter for frequency penalty. 1.0 means no penalty
/// Penalize new tokens based on their existing frequency in the text so far,
/// decreasing the model's likelihood to repeat the same line verbatim.
#[serde(default)]
#[schema(
exclusive_minimum = -2.0,
@ -195,9 +206,13 @@ pub(crate) struct GenerateParameters {
example = 0.1
)]
pub frequency_penalty: Option<f32>,
/// The number of highest probability vocabulary tokens to keep for top-k-filtering.
#[serde(default)]
#[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 10)]
pub top_k: Option<i32>,
/// Top-p value for nucleus sampling.
#[serde(default)]
#[schema(
exclusive_minimum = 0.0,
@ -207,6 +222,9 @@ pub(crate) struct GenerateParameters {
example = 0.95
)]
pub top_p: Option<f32>,
/// Typical Decoding mass
/// See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information.
#[serde(default)]
#[schema(
exclusive_minimum = 0.0,
@ -216,30 +234,48 @@ pub(crate) struct GenerateParameters {
example = 0.95
)]
pub typical_p: Option<f32>,
/// Activate logits sampling.
#[serde(default)]
#[schema(default = "false", example = true)]
pub do_sample: bool,
/// Maximum number of tokens to generate.
#[serde(default = "default_max_new_tokens")]
#[schema(nullable = true, default = "100", example = "20")]
pub max_new_tokens: Option<u32>,
/// Whether to prepend the prompt to the generated text
#[serde(default)]
#[schema(nullable = true, default = "null", example = false)]
pub return_full_text: Option<bool>,
/// Stop generating tokens if a member of `stop` is generated.
#[serde(default)]
#[schema(inline, max_items = 4, example = json ! (["photographer"]))]
pub stop: Vec<String>,
/// Truncate inputs tokens to the given size.
#[serde(default)]
#[schema(nullable = true, default = "null", example = "null")]
pub truncate: Option<usize>,
/// Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226).
#[serde(default)]
#[schema(default = "false", example = true)]
pub watermark: bool,
/// Whether to return generation details.
#[serde(default)]
#[schema(default = "true")]
pub details: bool,
/// Whether to return decoder input token logprobs and ids.
#[serde(default)]
#[schema(default = "true")]
#[schema(default = "false")]
pub decoder_input_details: bool,
/// Random sampling seed.
#[serde(default)]
#[schema(
exclusive_minimum = 0,
@ -248,10 +284,15 @@ pub(crate) struct GenerateParameters {
example = "null"
)]
pub seed: Option<u64>,
/// The number of highest probability vocabulary tokens to keep for top-n-filtering.
#[serde(default)]
#[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 5)]
pub top_n_tokens: Option<u32>,
/// Grammar constraints for the generation.
#[serde(default)]
#[schema(nullable = true, default = "null", example = "null")]
pub grammar: Option<GrammarType>,
}
@ -548,7 +589,9 @@ pub(crate) struct ChatCompletionChoice {
#[derive(Clone, Debug, Deserialize, Serialize, ToSchema)]
pub(crate) struct ChatCompletionDelta {
#[schema(example = "user")]
pub role: String,
// TODO Modify this to a true enum.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub role: Option<String>,
#[serde(default, skip_serializing_if = "Option::is_none")]
#[schema(example = "What is Deep Learning?")]
pub content: Option<String>,
@ -582,6 +625,31 @@ impl ChatCompletionChunk {
logprobs: Option<ChatCompletionLogprobs>,
finish_reason: Option<String>,
) -> Self {
let delta = match (delta, tool_calls) {
(Some(delta), _) => ChatCompletionDelta {
role: Some("assistant".to_string()),
content: Some(delta),
tool_calls: None,
},
(None, Some(tool_calls)) => ChatCompletionDelta {
role: Some("assistant".to_string()),
content: None,
tool_calls: Some(DeltaToolCall {
index: 0,
id: String::new(),
r#type: "function".to_string(),
function: Function {
name: None,
arguments: tool_calls[0].to_string(),
},
}),
},
(None, None) => ChatCompletionDelta {
role: None,
content: None,
tool_calls: None,
},
};
Self {
id: String::new(),
object: "text_completion".to_string(),
@ -590,19 +658,7 @@ impl ChatCompletionChunk {
system_fingerprint,
choices: vec![ChatCompletionChoice {
index: 0,
delta: ChatCompletionDelta {
role: "assistant".to_string(),
content: delta,
tool_calls: tool_calls.map(|tc| DeltaToolCall {
index: 0,
id: String::new(),
r#type: "function".to_string(),
function: Function {
name: None,
arguments: tc[0].to_string(),
},
}),
},
delta,
logprobs,
finish_reason,
}],

View File

@ -1,7 +1,7 @@
use axum::http::HeaderValue;
use clap::Parser;
use hf_hub::api::tokio::{Api, ApiBuilder, ApiRepo};
use hf_hub::{Repo, RepoType};
use hf_hub::{Cache, Repo, RepoType};
use opentelemetry::sdk::propagation::TraceContextPropagator;
use opentelemetry::sdk::trace;
use opentelemetry::sdk::trace::Sampler;
@ -11,7 +11,7 @@ use opentelemetry_otlp::WithExportConfig;
use std::fs::File;
use std::io::BufReader;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::path::Path;
use std::path::{Path, PathBuf};
use text_generation_client::{ClientError, ShardedClient};
use text_generation_router::config::Config;
use text_generation_router::{server, HubModelInfo, HubTokenizerConfig};
@ -162,7 +162,6 @@ async fn main() -> Result<(), RouterError> {
// Tokenizer instance
// This will only be used to validate payloads
let local_path = Path::new(&tokenizer_name);
let local_model = local_path.exists() && local_path.is_dir();
// Shared API builder initialization
let api_builder = || {
@ -181,109 +180,113 @@ async fn main() -> Result<(), RouterError> {
let use_api = revision.is_some() || !local_path.exists() || !local_path.is_dir();
// Initialize API if needed
#[derive(Clone)]
enum Type {
Api(Api),
Cache(Cache),
None,
}
let api = if use_api {
tracing::info!("Using the Hugging Face API");
match api_builder().build() {
Ok(api) => Some(api),
Err(_) => {
tracing::warn!("Unable to build the Hugging Face API");
None
if std::env::var("HF_HUB_OFFLINE") == Ok("1".to_string()) {
let cache = Cache::default();
tracing::warn!("Offline mode active using cache defaults");
Type::Cache(cache)
} else {
tracing::info!("Using the Hugging Face API");
match api_builder().build() {
Ok(api) => Type::Api(api),
Err(_) => {
tracing::warn!("Unable to build the Hugging Face API");
Type::None
}
}
}
} else {
None
Type::None
};
// Load tokenizer and model info
let (tokenizer, model_info, config) = if local_model {
let tokenizer = Tokenizer::from_file(local_path.join("tokenizer.json")).ok();
let model_info = HubModelInfo {
model_id: tokenizer_name.to_string(),
sha: None,
pipeline_tag: None,
};
let config: Option<Config> = std::fs::read_to_string(local_path.join("config.json"))
.ok()
.as_ref()
.and_then(|c| serde_json::from_str(c).ok());
let (tokenizer_filename, config_filename, tokenizer_config_filename, model_info) = match api {
Type::None => (
Some(local_path.join("tokenizer.json")),
Some(local_path.join("config.json")),
Some(local_path.join("tokenizer_config.json")),
None,
),
Type::Api(api) => {
let api_repo = api.repo(Repo::with_revision(
tokenizer_name.to_string(),
RepoType::Model,
revision.clone().unwrap_or_else(|| "main".to_string()),
));
(tokenizer, model_info, config)
} else if let Some(api) = api.clone() {
let api_repo = api.repo(Repo::with_revision(
tokenizer_name.to_string(),
RepoType::Model,
revision.clone().unwrap_or_else(|| "main".to_string()),
));
let tokenizer_filename = match api_repo.get("tokenizer.json").await {
Ok(tokenizer_filename) => Some(tokenizer_filename),
Err(_) => get_base_tokenizer(&api, &api_repo).await,
};
let config_filename = api_repo.get("config.json").await.ok();
let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok();
let tokenizer = match api_repo.get("tokenizer.json").await {
Ok(tokenizer_filename) => Tokenizer::from_file(tokenizer_filename).ok(),
Err(_) => get_base_tokenizer(&api, &api_repo).await,
};
let config: Option<Config> = api_repo.get("config.json").await.ok().and_then(|filename| {
std::fs::read_to_string(filename)
.ok()
.as_ref()
.and_then(|c| {
let config: Result<Config, _> = serde_json::from_str(c);
if let Err(err) = &config {
tracing::warn!("Could not parse config {err:?}");
}
config.ok()
})
});
let model_info = get_model_info(&api_repo).await.unwrap_or_else(|| {
tracing::warn!("Could not retrieve model info from the Hugging Face hub.");
HubModelInfo {
model_id: tokenizer_name.to_string(),
sha: None,
pipeline_tag: None,
}
});
(tokenizer, model_info, config)
} else {
// No API and no local model
return Err(RouterError::ArgumentValidation(
"No local model found and no revision specified".to_string(),
));
};
tracing::info!("Using config {config:?}");
// Load tokenizer config if found locally, or check if we can get it from the API if needed
let tokenizer_config = if let Some(path) = tokenizer_config_path {
tracing::info!("Using local tokenizer config from user specified path");
HubTokenizerConfig::from_file(&std::path::PathBuf::from(path))
} else if local_model {
tracing::info!("Using local tokenizer config");
HubTokenizerConfig::from_file(&local_path.join("tokenizer_config.json"))
} else {
match api {
Some(api) => {
tracing::info!("Using the Hugging Face API to retrieve tokenizer config");
let repo = Repo::with_revision(
tokenizer_name.to_string(),
RepoType::Model,
revision.unwrap_or("main".to_string()),
);
get_tokenizer_config(&api.repo(repo))
.await
.unwrap_or_else(|| {
tracing::warn!(
"Could not retrieve tokenizer config from the Hugging Face hub."
);
HubTokenizerConfig::default()
})
}
None => {
tracing::warn!("Could not find tokenizer config locally and no API specified");
HubTokenizerConfig::default()
}
let model_info = if let Some(model_info) = get_model_info(&api_repo).await {
Some(model_info)
} else {
tracing::warn!("Could not retrieve model info from the Hugging Face hub.");
None
};
(
tokenizer_filename,
config_filename,
tokenizer_config_filename,
model_info,
)
}
Type::Cache(cache) => {
let repo = cache.repo(Repo::with_revision(
tokenizer_name.to_string(),
RepoType::Model,
revision.clone().unwrap_or_else(|| "main".to_string()),
));
(
repo.get("tokenizer.json"),
repo.get("config.json"),
repo.get("tokenizer_config.json"),
None,
)
}
};
let tokenizer: Option<Tokenizer> =
tokenizer_filename.and_then(|filename| Tokenizer::from_file(filename).ok());
let config: Option<Config> = config_filename.and_then(|filename| {
std::fs::read_to_string(filename)
.ok()
.as_ref()
.and_then(|c| {
let config: Result<Config, _> = serde_json::from_str(c);
if let Err(err) = &config {
tracing::warn!("Could not parse config {err:?}");
}
config.ok()
})
});
let model_info = model_info.unwrap_or_else(|| HubModelInfo {
model_id: tokenizer_name.to_string(),
sha: None,
pipeline_tag: None,
});
// Read the JSON contents of the file as an instance of 'HubTokenizerConfig'.
let tokenizer_config: Option<HubTokenizerConfig> = if let Some(filename) = tokenizer_config_path
{
HubTokenizerConfig::from_file(filename)
} else {
tokenizer_config_filename.and_then(HubTokenizerConfig::from_file)
};
let tokenizer_config = tokenizer_config.unwrap_or_else(|| {
tracing::warn!("Could not find tokenizer config locally and no API specified");
HubTokenizerConfig::default()
});
tracing::info!("Using config {config:?}");
if tokenizer.is_none() {
tracing::warn!("Could not find a fast tokenizer implementation for {tokenizer_name}");
tracing::warn!("Rust input length validation and truncation is disabled");
@ -480,7 +483,7 @@ pub async fn get_model_info(api: &ApiRepo) -> Option<HubModelInfo> {
}
/// get base tokenizer
pub async fn get_base_tokenizer(api: &Api, api_repo: &ApiRepo) -> Option<Tokenizer> {
pub async fn get_base_tokenizer(api: &Api, api_repo: &ApiRepo) -> Option<PathBuf> {
let config_filename = api_repo.get("config.json").await.ok()?;
// Open the file in read-only mode with buffer.
@ -497,8 +500,7 @@ pub async fn get_base_tokenizer(api: &Api, api_repo: &ApiRepo) -> Option<Tokeniz
"main".to_string(),
));
let tokenizer_filename = api_base_repo.get("tokenizer.json").await.ok()?;
Tokenizer::from_file(tokenizer_filename).ok()
api_base_repo.get("tokenizer.json").await.ok()
} else {
None
}

View File

@ -1000,6 +1000,7 @@ async fn chat_completions(
tools,
tool_choice,
tool_prompt,
temperature,
..
} = req;
@ -1008,6 +1009,11 @@ async fn chat_completions(
let logprobs = logprobs.unwrap_or(false);
let tool_prompt = tool_prompt.unwrap_or_default();
let stop = stop.unwrap_or_default();
// enable greedy only when temperature is 0
let (do_sample, temperature) = match temperature {
Some(temperature) if temperature == 0.0 => (false, None),
other => (true, other),
};
// extract tool grammar if present
let tool_grammar = match ToolGrammar::apply(tools, tool_choice) {
@ -1054,13 +1060,13 @@ async fn chat_completions(
inputs: inputs.to_string(),
parameters: GenerateParameters {
best_of: None,
temperature: req.temperature,
temperature,
repetition_penalty,
frequency_penalty: req.frequency_penalty,
top_k: None,
top_p: req.top_p,
typical_p: None,
do_sample: true,
do_sample,
max_new_tokens,
return_full_text: None,
stop,
@ -1097,7 +1103,13 @@ async fn chat_completions(
let (content, tool_calls) = if tool_grammar.is_some() {
(None, Some(vec![stream_token.token.text]))
} else {
(Some(stream_token.token.text), None)
let content = if !stream_token.token.special {
Some(stream_token.token.text)
} else {
None
};
(content, None)
};
event

View File

@ -540,7 +540,57 @@ fn prepare_input(
inputs = modified_inputs;
tokenizer_query
}
Some(Config::Idefics) => RE.replace_all(&inputs, "<image>").into(),
Some(Config::Idefics2(config)) => {
let mut modified_inputs = String::with_capacity(inputs.len());
let mut tokenizer_query = String::with_capacity(inputs.len());
let mut start = 0;
for chunk in RE.find_iter(&inputs) {
let chunk_start = chunk.start();
let chunk_end = chunk.end();
if chunk_start != start {
modified_inputs.push_str(&inputs[start..chunk_start]);
tokenizer_query.push_str(&inputs[start..chunk_start]);
}
let (image_uri, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?;
let slots = config.get_number_of_features(height, width);
tokenizer_query.push_str("<fake_token_around_image>");
tokenizer_query.push_str(&"<image>".repeat(slots));
tokenizer_query.push_str("<fake_token_around_image>");
modified_inputs.push_str(&image_uri);
start = chunk_end;
}
if start != inputs.len() - 1 {
modified_inputs.push_str(&inputs[start..]);
tokenizer_query.push_str(&inputs[start..]);
}
inputs = modified_inputs;
tokenizer_query
}
Some(Config::Idefics) => {
let mut modified_inputs = String::with_capacity(inputs.len());
let mut tokenizer_query = String::with_capacity(inputs.len());
let mut start = 0;
for chunk in RE.find_iter(&inputs) {
let chunk_start = chunk.start();
let chunk_end = chunk.end();
if chunk_start != start {
modified_inputs.push_str(&inputs[start..chunk_start]);
tokenizer_query.push_str(&inputs[start..chunk_start]);
}
let (image_uri, _height, _width) = fetch_image(&inputs[chunk_start..chunk_end])?;
let slots = 1;
tokenizer_query.push_str(&"<image>".repeat(slots));
modified_inputs.push_str(&image_uri);
start = chunk_end;
}
if start != inputs.len() - 1 {
modified_inputs.push_str(&inputs[start..]);
tokenizer_query.push_str(&inputs[start..]);
}
inputs = modified_inputs;
tokenizer_query
}
_ => inputs.clone(),
};

1081
server/poetry.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -1,6 +1,6 @@
[tool.poetry]
name = "text-generation-server"
version = "2.0.0"
version = "2.0.1"
description = "Text Generation Inference Python gRPC Server"
authors = ["Olivier Dehaene <olivier@huggingface.co>"]
@ -24,9 +24,9 @@ opentelemetry-exporter-otlp = "^1.15.0"
opentelemetry-instrumentation-grpc = "^0.36b0"
hf-transfer = "^0.1.2"
sentencepiece = "^0.1.97"
tokenizers = "^0.15.0"
tokenizers = "^0.19.1"
huggingface-hub = "^0.19.3"
transformers = "^4.39"
transformers = "^4.40"
einops = "^0.6.1"
texttable = { version = "^1.6.7", optional = true }
datasets = { version = "^2.14.0", optional = true }

View File

@ -5,7 +5,7 @@ click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
filelock==3.13.3 ; python_version >= "3.9" and python_version < "3.13"
filelock==3.13.4 ; python_version >= "3.9" and python_version < "3.13"
fsspec==2024.2.0 ; python_version >= "3.9" and python_version < "3.13"
googleapis-common-protos==1.63.0 ; python_version >= "3.9" and python_version < "3.13"
grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
@ -14,7 +14,7 @@ grpcio-status==1.62.1 ; python_version >= "3.9" and python_version < "3.13"
grpcio==1.62.1 ; python_version >= "3.9" and python_version < "3.13"
hf-transfer==0.1.6 ; python_version >= "3.9" and python_version < "3.13"
huggingface-hub==0.19.4 ; python_version >= "3.9" and python_version < "3.13"
idna==3.6 ; python_version >= "3.9" and python_version < "3.13"
idna==3.7 ; python_version >= "3.9" and python_version < "3.13"
loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
@ -30,15 +30,15 @@ packaging==24.0 ; python_version >= "3.9" and python_version < "3.13"
pillow==10.3.0 ; python_version >= "3.9" and python_version < "3.13"
protobuf==4.25.3 ; python_version >= "3.9" and python_version < "3.13"
pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
regex==2023.12.25 ; python_version >= "3.9" and python_version < "3.13"
regex==2024.4.16 ; python_version >= "3.9" and python_version < "3.13"
requests==2.31.0 ; python_version >= "3.9" and python_version < "3.13"
safetensors==0.4.2 ; python_version >= "3.9" and python_version < "3.13"
safetensors==0.4.3 ; python_version >= "3.9" and python_version < "3.13"
scipy==1.13.0 ; python_version >= "3.9" and python_version < "3.13"
sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
setuptools==69.2.0 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.15.2 ; python_version >= "3.9" and python_version < "3.13"
setuptools==69.5.1 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.19.1 ; python_version >= "3.9" and python_version < "3.13"
tqdm==4.66.2 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.39.3 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.40.0 ; python_version >= "3.9" and python_version < "3.13"
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
typing-extensions==4.11.0 ; python_version >= "3.9" and python_version < "3.13"
urllib3==2.2.1 ; python_version >= "3.9" and python_version < "3.13"

View File

@ -5,7 +5,7 @@ click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
filelock==3.13.3 ; python_version >= "3.9" and python_version < "3.13"
filelock==3.13.4 ; python_version >= "3.9" and python_version < "3.13"
fsspec==2024.2.0 ; python_version >= "3.9" and python_version < "3.13"
googleapis-common-protos==1.63.0 ; python_version >= "3.9" and python_version < "3.13"
grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
@ -14,7 +14,7 @@ grpcio-status==1.62.1 ; python_version >= "3.9" and python_version < "3.13"
grpcio==1.62.1 ; python_version >= "3.9" and python_version < "3.13"
hf-transfer==0.1.6 ; python_version >= "3.9" and python_version < "3.13"
huggingface-hub==0.19.4 ; python_version >= "3.9" and python_version < "3.13"
idna==3.6 ; python_version >= "3.9" and python_version < "3.13"
idna==3.7 ; python_version >= "3.9" and python_version < "3.13"
loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
@ -30,15 +30,15 @@ packaging==24.0 ; python_version >= "3.9" and python_version < "3.13"
pillow==10.3.0 ; python_version >= "3.9" and python_version < "3.13"
protobuf==4.25.3 ; python_version >= "3.9" and python_version < "3.13"
pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
regex==2023.12.25 ; python_version >= "3.9" and python_version < "3.13"
regex==2024.4.16 ; python_version >= "3.9" and python_version < "3.13"
requests==2.31.0 ; python_version >= "3.9" and python_version < "3.13"
safetensors==0.4.2 ; python_version >= "3.9" and python_version < "3.13"
safetensors==0.4.3 ; python_version >= "3.9" and python_version < "3.13"
scipy==1.13.0 ; python_version >= "3.9" and python_version < "3.13"
sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
setuptools==69.2.0 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.15.2 ; python_version >= "3.9" and python_version < "3.13"
setuptools==69.5.1 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.19.1 ; python_version >= "3.9" and python_version < "3.13"
tqdm==4.66.2 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.39.3 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.40.0 ; python_version >= "3.9" and python_version < "3.13"
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
typing-extensions==4.11.0 ; python_version >= "3.9" and python_version < "3.13"
urllib3==2.2.1 ; python_version >= "3.9" and python_version < "3.13"

View File

@ -68,6 +68,7 @@ try:
)
from text_generation_server.models.idefics import IDEFICSSharded
from text_generation_server.models.llava_next import LlavaNext
from text_generation_server.models.idefics2 import Idefics2
from text_generation_server.models.flash_mistral import FlashMistral
from text_generation_server.models.flash_mixtral import FlashMixtral
from text_generation_server.models.flash_phi import FlashPhi
@ -327,7 +328,7 @@ def get_model(
trust_remote_code=trust_remote_code,
)
elif model_type == "llama" or model_type == "baichuan":
elif model_type == "llama" or model_type == "baichuan" or model_type == "phi3":
if FLASH_ATTENTION:
return FlashLlama(
model_id,
@ -579,6 +580,18 @@ def get_model(
)
else:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
if model_type == "idefics2":
if FLASH_ATTENTION:
return Idefics2(
model_id,
revision,
quantize=quantize,
use_medusa=use_medusa,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
else:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
if model_type == "llava_next":
if FLASH_ATTENTION:

View File

@ -45,58 +45,6 @@ if IS_ROCM_SYSTEM:
raise ImportError(f"Could not load `vllm._custom_C`. Full error: {e}")
class LlamaConfig(PretrainedConfig):
def __init__(
self,
vocab_size=32000,
hidden_size=4096,
intermediate_size=11008,
num_hidden_layers=32,
num_attention_heads=32,
num_key_value_heads=None,
hidden_act="silu",
max_position_embeddings=2048,
initializer_range=0.02,
rms_norm_eps=1e-6,
use_cache=True,
pad_token_id=0,
bos_token_id=1,
eos_token_id=2,
pretraining_tp=1,
tie_word_embeddings=False,
rope_scaling=None,
rope_theta=10000.0,
**kwargs,
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
# for backward compatibility
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.pretraining_tp = pretraining_tp
self.use_cache = use_cache
self.rope_scaling = rope_scaling
self.rope_theta = rope_theta
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
def load_attention(config, prefix, weights):
if config.num_attention_heads != config.num_key_value_heads:
return _load_gqa(config, prefix, weights)
@ -108,6 +56,13 @@ def load_attention(config, prefix, weights):
weights=weights,
bias=False,
)
elif config.model_type == "phi3":
return TensorParallelColumnLinear.load_qkv(
config,
prefix=f"{prefix}.qkv_proj",
weights=weights,
bias=False,
)
else:
return TensorParallelColumnLinear.load_multi(
config,
@ -265,13 +220,21 @@ class LlamaMLP(nn.Module):
)
)
# Fuse gate and up proj
self.gate_up_proj = TensorParallelColumnLinear.load_multi(
config,
prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
weights=weights,
dim=0,
bias=False,
)
if config.model_type == "phi3":
self.gate_up_proj = TensorParallelColumnLinear.load_gate_up(
config,
prefix=f"{prefix}.gate_up_proj",
weights=weights,
bias=False,
)
else:
self.gate_up_proj = TensorParallelColumnLinear.load_multi(
config,
prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
weights=weights,
dim=0,
bias=False,
)
self.down_proj = TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.down_proj",

View File

@ -409,23 +409,29 @@ class MistralModel(torch.nn.Module):
class FlashMistralForCausalLM(torch.nn.Module):
def __init__(self, prefix, config, weights):
def __init__(self, prefix, config, weights, name=None):
if name is None:
name = "model"
super().__init__()
self.embed_tokens = TensorParallelEmbedding(
prefix=(
"model.embed_tokens" if not prefix else f"{prefix}.model.embed_tokens"
f"{name}.embed_tokens"
if not prefix
else f"{prefix}.{name}.embed_tokens"
),
weights=weights,
)
self.model = MistralModel(
prefix="model" if not prefix else f"{prefix}.model",
prefix=name if not prefix else f"{prefix}.{name}",
config=config,
weights=weights,
)
self.lm_head = SpeculativeHead.load(
config,
prefix="lm_head" if not prefix else f"{prefix}.lm_head",
# TODO dirty hack for idefics2.
prefix=(
"lm_head" if not prefix or name != "model" else f"{prefix}.lm_head"
),
weights=weights,
)
self.max_past = config.sliding_window

View File

@ -0,0 +1,829 @@
# coding=utf-8
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch Idefics2 model."""
from typing import List, Optional, Tuple, Union
import torch
import torch.utils.checkpoint
from torch import nn
import math
from transformers.activations import ACT2FN
from transformers.image_processing_utils import select_best_resolution
from text_generation_server.models.custom_modeling.vlm import (
load_text_model,
load_vision_model,
)
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
from text_generation_server.utils.layers import (
TensorParallelColumnLinear,
TensorParallelEmbedding,
TensorParallelRowLinear,
)
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
"""
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(
batch, num_key_value_heads, n_rep, slen, head_dim
)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
class Idefics2VisionEmbeddings(nn.Module):
"""
This is a modified version of `siglip.modelign_siglip.SiglipVisionEmbeddings` to enable images of variable
resolution.
The modifications are adapted from [Patch n' Pack: NaViT, a Vision Transformer for any Aspect Ratio and Resolution](https://arxiv.org/abs/2307.06304)
which allows treating images in their native aspect ratio and without the need to resize them to the same
fixed size. In particular, we start from the original pre-trained SigLIP model
(which uses images of fixed-size square images) and adapt it by training on images of variable resolutions.
"""
def __init__(self, prefix, config, weights):
super().__init__()
self.embed_dim = config.hidden_size
self.image_size = config.image_size
self.patch_size = config.patch_size
self.patch_embedding = nn.Conv2d(
in_channels=config.num_channels,
out_channels=self.embed_dim,
kernel_size=self.patch_size,
stride=self.patch_size,
padding="valid",
)
self.patch_embedding.weight = nn.Parameter(
weights.get_tensor(f"{prefix}.patch_embedding.weight"), requires_grad=False
)
self.patch_embedding.bias = nn.Parameter(
weights.get_tensor(f"{prefix}.patch_embedding.bias"), requires_grad=False
)
self.num_patches_per_side = self.image_size // self.patch_size
self.num_patches = self.num_patches_per_side**2
self.num_positions = self.num_patches
self.position_embedding = TensorParallelEmbedding(
prefix=f"{prefix}.position_embedding", weights=weights
)
def forward(
self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.BoolTensor
) -> torch.Tensor:
batch_size, _, max_im_h, max_im_w = pixel_values.shape
patch_embeds = self.patch_embedding(pixel_values)
embeddings = patch_embeds.flatten(2).transpose(1, 2)
max_nb_patches_h, max_nb_patches_w = (
max_im_h // self.patch_size,
max_im_w // self.patch_size,
)
boundaries = torch.arange(
1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side
)
position_ids = torch.full(
size=(batch_size, max_nb_patches_h * max_nb_patches_w), fill_value=0
)
for batch_idx, p_attn_mask in enumerate(patch_attention_mask):
nb_patches_h = p_attn_mask[:, 0].sum()
nb_patches_w = p_attn_mask[0].sum()
fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h)
fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w)
bucket_coords_h = torch.bucketize(
fractional_coords_h, boundaries, right=True
)
bucket_coords_w = torch.bucketize(
fractional_coords_w, boundaries, right=True
)
pos_ids = (
bucket_coords_h[:, None] * self.num_patches_per_side + bucket_coords_w
).flatten()
position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids
position_ids = position_ids.to(self.position_embedding.weight.device)
embeddings = embeddings + self.position_embedding(position_ids)
return embeddings
class Idefics2VisionAttention(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_size = self.embed_dim // self.num_heads
if self.head_size * self.num_heads != self.embed_dim:
raise ValueError(
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
f" {self.num_heads})."
)
self.scale = self.head_size**-0.5
self.dropout = config.attention_dropout
self.num_heads = self.num_heads // weights.process_group.size()
self.embed_dim = self.embed_dim // weights.process_group.size()
self.qkv = TensorParallelColumnLinear.load_multi(
config,
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
dim=0,
weights=weights,
bias=True,
)
self.out_proj = TensorParallelRowLinear.load(
config=config, prefix=f"{prefix}.out_proj", weights=weights, bias=True
)
self.is_causal = False
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
batch_size, q_len, _ = hidden_states.size()
qkv = self.qkv(hidden_states)
query_states, key_states, value_states = qkv.split(
[
self.head_size * self.num_heads,
self.head_size * self.num_heads,
self.head_size * self.num_heads,
],
dim=2,
)
query_states = query_states.view(
batch_size, q_len, self.num_heads, self.head_size
).transpose(1, 2)
key_states = key_states.view(
batch_size, q_len, self.num_heads, self.head_size
).transpose(1, 2)
value_states = value_states.view(
batch_size, q_len, self.num_heads, self.head_size
).transpose(1, 2)
k_v_seq_len = key_states.shape[-2]
attn_weights = (
torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale
)
if attn_weights.size() != (batch_size, self.num_heads, q_len, k_v_seq_len):
raise ValueError(
f"Attention weights should be of size {(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is"
f" {attn_weights.size()}"
)
if attention_mask is not None:
if attention_mask.size() != (batch_size, 1, q_len, k_v_seq_len):
raise ValueError(
f"Attention mask should be of size {(batch_size, 1, q_len, k_v_seq_len)}, but is {attention_mask.size()}"
)
attn_weights = attn_weights + attention_mask
# upcast attention to fp32
attn_weights = nn.functional.softmax(
attn_weights, dim=-1, dtype=torch.float32
).to(query_states.dtype)
attn_weights = nn.functional.dropout(
attn_weights, p=self.dropout, training=self.training
)
attn_output = torch.matmul(attn_weights, value_states)
if attn_output.size() != (batch_size, self.num_heads, q_len, self.head_size):
raise ValueError(
f"`attn_output` should be of size {(batch_size, self.num_heads, q_len, self.head_size)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim)
attn_output = self.out_proj(attn_output)
return attn_output
class Idefics2VisionMLP(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
self.config = config
self.activation_fn = ACT2FN[config.hidden_act]
self.fc1 = TensorParallelColumnLinear.load(
prefix=f"{prefix}.fc1", config=config, weights=weights, bias=True
)
self.fc2 = TensorParallelRowLinear.load(
prefix=f"{prefix}.fc2", config=config, weights=weights, bias=True
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.fc1(hidden_states)
hidden_states = self.activation_fn(hidden_states)
hidden_states = self.fc2(hidden_states)
return hidden_states
class Idefics2EncoderLayer(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
self.embed_dim = config.hidden_size
self.self_attn = Idefics2VisionAttention(
prefix=f"{prefix}.self_attn", config=config, weights=weights
)
self.layer_norm1 = nn.LayerNorm.load(
prefix=f"{prefix}.layer_norm1", eps=config.layer_norm_eps, weights=weights
)
self.layer_norm2 = nn.LayerNorm.load(
prefix=f"{prefix}.layer_norm2", eps=config.layer_norm_eps, weights=weights
)
self.mlp = Idefics2VisionMLP(
prefix=f"{prefix}.mlp", config=config, weights=weights
)
# Copied from transformers.models.siglip.modeling_siglip.SiglipEncoderLayer.forward
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
) -> torch.Tensor:
residual = hidden_states
hidden_states = self.layer_norm1(hidden_states)
hidden_states = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.layer_norm2(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
return hidden_states
class Idefics2Encoder(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
self.config = config
self.layers = nn.ModuleList(
[
Idefics2EncoderLayer(
prefix=f"{prefix}.layers.{i}", config=config, weights=weights
)
for i in range(config.num_hidden_layers)
]
)
# Ignore copy
def forward(
self,
inputs_embeds,
attention_mask: Optional[torch.Tensor] = None,
):
hidden_states = inputs_embeds
for encoder_layer in self.layers:
hidden_states = encoder_layer(
hidden_states,
attention_mask,
)
return hidden_states
class Idefics2VisionTransformer(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
self.config = config
self.embeddings = Idefics2VisionEmbeddings(
prefix=f"{prefix}.embeddings", config=config, weights=weights
)
self.encoder = Idefics2Encoder(
prefix=f"{prefix}.encoder", config=config, weights=weights
)
self.post_layernorm = nn.LayerNorm.load(
prefix=f"{prefix}.post_layernorm",
weights=weights,
eps=config.layer_norm_eps,
)
def forward(
self,
pixel_values,
patch_attention_mask: Optional[torch.BoolTensor] = None,
):
batch_size = pixel_values.size(0)
if patch_attention_mask is None:
patch_size = self.config.patch_size
patch_attention_mask = torch.ones(
(
batch_size,
pixel_values.size(2) // patch_size,
pixel_values.size(3) // patch_size,
)
)
patch_attention_mask = patch_attention_mask.to(
dtype=torch.bool, device=pixel_values.device
)
hidden_states = self.embeddings(
pixel_values=pixel_values, patch_attention_mask=patch_attention_mask
)
patch_attention_mask = patch_attention_mask.view(batch_size, -1)
# The call to `_upad_input` in `_flash_attention_forward` is expensive
# So when the `patch_attention_mask` is full of 1s (i.e. attending to the whole sequence),
# avoiding passing the attention_mask, which is equivalent to attending to the full sequence
if not torch.any(~patch_attention_mask):
patch_attention_mask = None
else:
patch_attention_mask = _prepare_4d_attention_mask(
patch_attention_mask, hidden_states.dtype
)
encoder_outputs = self.encoder(
inputs_embeds=hidden_states,
attention_mask=patch_attention_mask,
)
last_hidden_state = encoder_outputs
last_hidden_state = self.post_layernorm(last_hidden_state)
return last_hidden_state
class Idefics2MLP(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
act = config.text_config.hidden_act
self.act = (
ACT2FN[act]
if "gelu" not in act
else lambda x: torch.nn.functional.gelu(
x,
approximate=(
"tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
),
)
)
self.gate_up_proj = TensorParallelColumnLinear.load_multi(
config,
prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
weights=weights,
dim=0,
bias=False,
)
self.down_proj = TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.down_proj",
weights=weights,
bias=False,
)
def forward(self, hidden_states):
start_shape = hidden_states.shape[:-1]
gate_up_states = self.gate_up_proj(hidden_states)
intermediate_size = gate_up_states.shape[-1] // 2
gate_up_states = gate_up_states.view(-1, 2, intermediate_size)
return self.down_proj(
self.act(gate_up_states[:, 0]) * gate_up_states[:, 1]
).view(*start_shape, -1)
class Idefics2RMSNorm(nn.Module):
def __init__(self, prefix, weights, eps):
"""
Idefics2RMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.weight = nn.Parameter(
weights.get_tensor(f"{prefix}.weight"), requires_grad=False
)
self.variance_epsilon = eps
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)
class Idefics2PerceiverAttention(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
self.layer_idx = None
self.hidden_size = config.text_config.hidden_size
self.num_heads = config.perceiver_config.resampler_n_heads
self.head_size = config.perceiver_config.resampler_head_dim
self.num_key_value_heads = config.perceiver_config.num_key_value_heads
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
self.attention_dropout = config.perceiver_config.attention_dropout
self.num_heads = self.num_heads // weights.process_group.size()
self.num_key_value_heads = (
self.num_key_value_heads // weights.process_group.size()
)
self.q_proj = TensorParallelColumnLinear.load(
config,
prefix=f"{prefix}.q_proj",
weights=weights,
bias=False,
)
self.kv = TensorParallelColumnLinear.load_multi(
config,
prefixes=[f"{prefix}.k_proj", f"{prefix}.v_proj"],
dim=0,
weights=weights,
bias=False,
)
self.o_proj = TensorParallelRowLinear.load(
config=config, prefix=f"{prefix}.o_proj", weights=weights, bias=False
)
self.is_causal = False
def forward(
self,
latents: torch.Tensor,
context: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = latents.size()
kv_seq_len = q_len + context.size()[1]
hidden_states = torch.concat([context, latents], dim=-2)
query_states = self.q_proj(latents)
kv = self.kv(hidden_states)
key_states, value_states = kv.split(
[
self.head_size * self.num_key_value_heads,
self.head_size * self.num_key_value_heads,
],
dim=2,
)
query_states = query_states.view(
bsz, q_len, self.num_heads, self.head_size
).transpose(1, 2)
key_states = key_states.view(
bsz, kv_seq_len, self.num_key_value_heads, self.head_size
).transpose(1, 2)
value_states = value_states.view(
bsz, kv_seq_len, self.num_key_value_heads, self.head_size
).transpose(1, 2)
# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
attn_weights = torch.matmul(
query_states, key_states.transpose(2, 3)
) / math.sqrt(self.head_size)
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
raise ValueError(
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
f" {attn_weights.size()}"
)
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
)
attn_weights = attn_weights + attention_mask
# upcast attention to fp32
attn_weights = nn.functional.softmax(
attn_weights, dim=-1, dtype=torch.float32
).to(query_states.dtype)
attn_output = torch.matmul(attn_weights, value_states)
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_size):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_size)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_size)
attn_output = self.o_proj(attn_output)
return attn_output
class Idefics2PerceiverLayer(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
self.hidden_size = config.text_config.hidden_size
self.n_latents = config.perceiver_config.resampler_n_latents
self.depth = config.perceiver_config.resampler_depth
self.rms_norm_eps = config.text_config.rms_norm_eps
self.input_latents_norm = Idefics2RMSNorm(
prefix=f"{prefix}.input_latents_norm",
weights=weights,
eps=self.rms_norm_eps,
)
self.input_context_norm = Idefics2RMSNorm(
prefix=f"{prefix}.input_context_norm",
weights=weights,
eps=self.rms_norm_eps,
)
self.self_attn = Idefics2PerceiverAttention(
prefix=f"{prefix}.self_attn", config=config, weights=weights
)
self.post_attention_layernorm = Idefics2RMSNorm(
prefix=f"{prefix}.post_attention_layernorm",
weights=weights,
eps=self.rms_norm_eps,
)
self.mlp = Idefics2MLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
def forward(
self,
latents: torch.Tensor,
context: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
):
"""
Args:
latents (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
context (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
`(batch, sequence_length)` where padding elements are indicated by 0.
"""
residual = latents
latents = self.input_latents_norm(latents)
context = self.input_context_norm(context)
latents = self.self_attn(
latents=latents,
context=context,
attention_mask=attention_mask,
)
latents = residual + latents
residual = latents
latents = self.post_attention_layernorm(latents)
latents = self.mlp(latents)
latents = residual + latents
return latents
class Idefics2PerceiverResampler(nn.Module):
def __init__(self, prefix, config, weights) -> None:
super().__init__()
self.hidden_size = config.text_config.hidden_size
self.hidden_act = config.perceiver_config.hidden_act
self.n_latents = config.perceiver_config.resampler_n_latents
self.depth = config.perceiver_config.resampler_depth
self.rms_norm_eps = config.text_config.rms_norm_eps
# Create Latents for Perceiver
self.latents = weights.get_tensor(f"{prefix}.latents")
# Create Transformer Blocks
self.layers = nn.ModuleList(
[
Idefics2PerceiverLayer(
prefix=f"{prefix}.layers.{idx}", config=config, weights=weights
)
for idx in range(self.depth)
]
)
self.norm = Idefics2RMSNorm(
prefix=f"{prefix}.norm",
weights=weights,
eps=config.text_config.rms_norm_eps,
)
def forward(
self,
context: torch.Tensor,
attention_mask,
) -> torch.Tensor:
# seq embed -> bsz seq embed
latents = self.latents.unsqueeze(0).expand(
(context.shape[0], *self.latents.size())
)
latent_attention_mask = torch.ones(
(attention_mask.size(0), latents.size(1)),
dtype=attention_mask.dtype,
device=attention_mask.device,
)
attention_mask = torch.cat([attention_mask, latent_attention_mask], dim=-1)
attention_mask = _prepare_4d_attention_mask(
attention_mask, latents.dtype, tgt_len=self.n_latents
)
compressed_context = latents
for perceiver_layer in self.layers:
compressed_context = perceiver_layer(
compressed_context,
context,
attention_mask=attention_mask,
)
compressed_context = self.norm(compressed_context)
return compressed_context
class Idefics2Connector(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
self.modality_projection = Idefics2MLP(
prefix=f"{prefix}.modality_projection", config=config, weights=weights
)
self.perceiver_resampler = Idefics2PerceiverResampler(
prefix=f"{prefix}.perceiver_resampler", config=config, weights=weights
)
def forward(self, image_hidden_states, attention_mask):
image_hidden_states = self.modality_projection(image_hidden_states)
image_hidden_states = self.perceiver_resampler(
context=image_hidden_states, attention_mask=attention_mask
)
return image_hidden_states
class Idefics2ForConditionalGeneration(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
config.vision_config.quantize = config.quantize
config.vision_config.use_medusa = config.use_medusa
config.text_config.quantize = config.quantize
config.text_config.use_medusa = config.use_medusa
vision_config = config.vision_config
self.text_model = load_text_model(
prefix="model" if not prefix else f"{prefix}.model",
config=config.text_config,
weights=weights,
name="text_model",
)
self.dtype = weights.dtype
self.vision_model = Idefics2VisionTransformer(
prefix=f"{prefix}.model.vision_model" if prefix else "model.vision_model",
config=vision_config,
weights=weights,
)
self.connector = Idefics2Connector(
prefix=f"{prefix}.model.connector" if prefix else "model.connector",
config=config,
weights=weights,
)
self.config = config
self.image_seq_len = config.perceiver_config.resampler_n_latents
self.image_token_id = config.image_token_id
self.pad_token_id = (
config.pad_token_id if config.pad_token_id is not None else -1
)
def _merge_input_ids_with_image_features(
self,
input_ids: torch.Tensor,
inputs_embeds: torch.Tensor,
image_features: torch.Tensor,
):
"""In place merges in vision_embeddings with inputs_embeds."""
# mask = input_ids == self.config.image_token_index
mask = input_ids == self.config.image_token_id
# Let's pray we have enabled enough slots !
inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
return inputs_embeds
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
pixel_values: torch.FloatTensor = None,
pixel_attention_mask: Optional[torch.BoolTensor] = None,
# Unused here
image_sizes: Optional[torch.Tensor] = None,
):
inputs_embeds = self.text_model.embed_tokens(input_ids)
if pixel_values is not None:
batch_size, num_images, num_channels, height, width = pixel_values.shape
all_states = []
all_pixel_values = pixel_values
all_pixel_mask = pixel_attention_mask
for i in range(batch_size):
pixel_values = all_pixel_values.to(
dtype=self.dtype
) # fp16 compatibility
pixel_values = pixel_values[i : i + 1]
pixel_values = pixel_values.view(num_images, *pixel_values.shape[2:])
# Remove padding images - padding images are full 0.
nb_values_per_image = pixel_values.shape[1:].numel()
real_images_inds = (pixel_values == 0.0).sum(
dim=(-1, -2, -3)
) != nb_values_per_image
pixel_values = pixel_values[real_images_inds].contiguous()
# Handle the vision attention mask
if pixel_attention_mask is None:
pixel_attention_mask = torch.ones(
size=(
pixel_values.size(0),
pixel_values.size(2),
pixel_values.size(3),
),
dtype=torch.bool,
device=pixel_values.device,
)
else:
# Remove padding images from the mask/pP p
pixel_attention_mask = all_pixel_mask[i : i + 1]
pixel_attention_mask = pixel_attention_mask.view(
1 * num_images, *pixel_attention_mask.shape[2:]
)
pixel_attention_mask = pixel_attention_mask[
real_images_inds
].contiguous()
patch_size = self.config.vision_config.patch_size
patches_subgrid = pixel_attention_mask.unfold(
dimension=1, size=patch_size, step=patch_size
)
patches_subgrid = patches_subgrid.unfold(
dimension=2, size=patch_size, step=patch_size
)
patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
# Get sequence from the vision encoder
image_hidden_states = self.vision_model(
pixel_values=pixel_values,
patch_attention_mask=patch_attention_mask,
)
# Modality projection & resampling
image_hidden_states = self.connector(
image_hidden_states,
attention_mask=patch_attention_mask.view(pixel_values.size(0), -1),
)
all_states.append(image_hidden_states)
image_hidden_states = torch.stack(all_states, dim=0)
# When we generate, we don't want to replace the potential image_token_id that we generated by images
# that simply don't exist
inputs_embeds = self._merge_input_ids_with_image_features(
input_ids, inputs_embeds, image_hidden_states
)
hidden_states = self.text_model.model(
inputs_embeds=inputs_embeds,
position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill,
kv_cache=kv_cache,
block_tables=block_tables,
slots=slots,
input_lengths=input_lengths,
max_s=max_s,
true_max_s=max_s,
prefill_cache_indices=None,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]
logits, speculative_logits = self.text_model.lm_head(hidden_states)
return logits, speculative_logits

View File

@ -23,6 +23,10 @@ from torch import nn
from transformers.activations import ACT2FN
from transformers.image_processing_utils import select_best_resolution
from text_generation_server.models.custom_modeling.vlm import (
load_text_model,
load_vision_model,
)
from text_generation_server.utils.layers import (
TensorParallelColumnLinear,
TensorParallelRowLinear,
@ -105,36 +109,6 @@ class LlavaNextMultiModalProjector(nn.Module):
return hidden_states
def load_vision_model(prefix, config, weights):
if config.model_type == "clip_vision_model":
from text_generation_server.models.custom_modeling.clip import (
CLIPVisionTransformer,
)
return CLIPVisionTransformer(
prefix=f"{prefix}.vision_model", config=config, weights=weights
)
else:
raise RuntimeError(f"Unsupported model type {config.model_type}")
def load_text_model(prefix, config, weights):
if config.model_type == "llama":
from text_generation_server.models.custom_modeling.flash_llama_modeling import (
FlashLlamaForCausalLM,
)
return FlashLlamaForCausalLM(prefix, config, weights)
elif config.model_type == "mistral":
from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
FlashMistralForCausalLM,
)
return FlashMistralForCausalLM(prefix, config, weights)
else:
raise RuntimeError(f"Unsupported model type {config.model_type}")
class LlavaNextForConditionalGeneration(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
@ -180,7 +154,12 @@ class LlavaNextForConditionalGeneration(nn.Module):
"""In place merges in vision_embeddings with inputs_embeds."""
mask = input_ids == self.config.image_token_index
# Let's pray we have enabled enough slots !
inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
try:
inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
except Exception as e:
raise RuntimeError(
f"Cannot fill images right now. If error happens at warmup, make sure you have enough `--max-input-tokens` to handle images. If error happens at regular runtime, please fill in an issue: {e}"
)
return inputs_embeds
def forward(
@ -196,6 +175,8 @@ class LlavaNextForConditionalGeneration(nn.Module):
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
pixel_values: torch.FloatTensor = None,
# Unused for this model
pixel_attention_mask=None,
image_sizes: Optional[torch.LongTensor] = None,
):
inputs_embeds = self.language_model.embed_tokens(input_ids)

View File

@ -0,0 +1,28 @@
def load_text_model(prefix, config, weights, name=None):
if config.model_type == "llama":
from text_generation_server.models.custom_modeling.flash_llama_modeling import (
FlashLlamaForCausalLM,
)
return FlashLlamaForCausalLM(prefix, config, weights)
elif config.model_type == "mistral":
from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
FlashMistralForCausalLM,
)
return FlashMistralForCausalLM(prefix, config, weights, name=name)
else:
raise RuntimeError(f"Unsupported model type {config.model_type}")
def load_vision_model(prefix, config, weights):
if config.model_type == "clip_vision_model":
from text_generation_server.models.custom_modeling.clip import (
CLIPVisionTransformer,
)
return CLIPVisionTransformer(
prefix=f"{prefix}.vision_model", config=config, weights=weights
)
else:
raise RuntimeError(f"Unsupported model type {config.model_type}")

View File

@ -2,14 +2,13 @@ import torch
import torch.distributed
from opentelemetry import trace
from transformers import AutoConfig, AutoTokenizer
from transformers import AutoConfig, AutoTokenizer, GenerationConfig
from transformers.models.llama import LlamaTokenizer
from typing import Optional
from text_generation_server.models import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_llama_modeling import (
FlashLlamaForCausalLM,
LlamaConfig,
)
from text_generation_server.utils import (
initialize_torch_distributed,
@ -53,8 +52,17 @@ class FlashLlama(FlashCausalLM):
truncation_side="left",
trust_remote_code=trust_remote_code,
)
try:
generation_config = GenerationConfig.from_pretrained(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
if isinstance(generation_config.eos_token_id, (list, set)):
# TODO Huge hack
tokenizer._eos_token_ids = set(generation_config.eos_token_id)
except Exception:
pass
config = LlamaConfig.from_pretrained(
config = AutoConfig.from_pretrained(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
config.quantize = quantize

View File

@ -511,18 +511,33 @@ class BaseFlashMistral(FlashCausalLM):
cuda_graph = self.cuda_graphs.get(padded_bs, None)
if cu_seqlen_prefill is not None or cuda_graph is None:
logits, speculative_logits = self.model.forward(
input_ids=input_ids,
position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill,
kv_cache=kv_cache,
block_tables=block_tables,
slots=slots,
input_lengths=input_lengths,
max_s=max_s,
prefill_cache_indices=batch.prefill_cache_indices,
lm_head_indices=lm_head_indices,
)
if cu_seqlen_prefill is None:
logits, speculative_logits = self.compiled_model(
input_ids=input_ids,
position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill,
kv_cache=kv_cache,
block_tables=block_tables,
slots=slots,
input_lengths=input_lengths,
max_s=max_s,
prefill_cache_indices=batch.prefill_cache_indices,
lm_head_indices=lm_head_indices,
)
else:
logits, speculative_logits = self.model.forward(
input_ids=input_ids,
position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill,
kv_cache=kv_cache,
block_tables=block_tables,
slots=slots,
input_lengths=input_lengths,
max_s=max_s,
prefill_cache_indices=batch.prefill_cache_indices,
lm_head_indices=lm_head_indices,
)
if batch.prefill_cache_indices is not None:
batch.prefill_cache_indices = None
return logits, speculative_logits

View File

@ -0,0 +1,51 @@
import torch
from typing import Optional, Tuple
from transformers import (
AutoProcessor,
)
from text_generation_server.models.custom_modeling.idefics2 import (
Idefics2ForConditionalGeneration,
)
from text_generation_server.models.vlm_causal_lm import VlmCausalLM
class Idefics2(VlmCausalLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
use_medusa: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
self.processor = AutoProcessor.from_pretrained(
model_id,
revision=revision,
trust_remote_code=trust_remote_code,
# XXX: Extremely important to cap resolution in order to limit
# VRAM usage.
size={"longest_edge": 448, "shortest_edge": 378},
)
super().__init__(
model_cls=Idefics2ForConditionalGeneration,
model_id=model_id,
revision=revision,
quantize=quantize,
use_medusa=use_medusa,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
def get_layer_config(self, model) -> Tuple[int, int, int]:
return (
len(model.text_model.model.layers),
model.text_model.model.num_key_value_heads,
model.text_model.model.head_size,
)
def max_past(self) -> Optional[int]:
return getattr(self.model.text_model, "max_past", None)

View File

@ -1,6 +1,6 @@
import torch
from typing import Optional
from typing import Optional, Tuple
from transformers import (
AutoProcessor,
@ -34,3 +34,13 @@ class LlavaNext(VlmCausalLM):
dtype=dtype,
trust_remote_code=trust_remote_code,
)
def get_layer_config(self, model) -> Tuple[int, int, int]:
return (
len(model.language_model.model.layers),
model.language_model.model.num_key_value_heads,
model.language_model.model.head_size,
)
def max_past(self) -> Optional[int]:
return getattr(self.model.language_model, "max_past", None)

View File

@ -27,7 +27,14 @@ class Model(ABC):
):
self.model = model.eval()
self.tokenizer = tokenizer
# all_special_ids is not set correctly if the rust tokenizer is unpacked
# TODO report this to transformers.
other_special_ids = {
id for id, token in tokenizer.added_tokens_decoder.items() if token.special
}
self.all_special_ids = set(tokenizer.all_special_ids)
self.all_special_ids.update(other_special_ids)
self.requires_padding = requires_padding
self.dtype = dtype
self.device = device

View File

@ -64,6 +64,46 @@ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
return height // patch_size, width // patch_size
def image_text_replacement(image_input, config, image_id) -> str:
if config.model_type == "idefics2":
# TODO technically depends on image splitting which is not implemented.
num_features = 320
return (
"<fake_token_around_image>"
+ "<image>" * num_features
+ "<fake_token_around_image>"
)
elif config.model_type == "llava_next":
height, width = image_input["image_sizes"][image_id]
num_features = get_number_of_features(height, width, config)
from loguru import logger
logger.info(f"Found {num_features} in image of resolution {height}x{width}")
return "<image>" * num_features
else:
raise RuntimeError(f"Unknown config {config.model_type} for multimodal")
def get_unpadded_features(
height: int, width: int, npatches: int, num_patch_height: int, num_patch_width: int
) -> Tuple[int, int]:
current_height = npatches * num_patch_height
current_width = npatches * num_patch_width
aspect_ratio: float = width / height
current_aspect_ratio: float = current_width / current_height
if aspect_ratio > current_aspect_ratio:
new_height = (height * current_width) // width
current_height = new_height
else:
new_width = (width * current_height) // height
current_width = new_width
unpadded_features = current_height * current_width
newline_features = current_height
return (unpadded_features, newline_features)
def get_number_of_features(height: int, width: int, config) -> int:
# From config
# Hardcoded for CLIP for now
@ -81,12 +121,9 @@ def get_number_of_features(height: int, width: int, config) -> int:
image_grid_pinpoints,
image_size,
)
height_of_patch = math.ceil(height / width * npatches)
unpadded_features = npatches * height_of_patch * num_patch_height * num_patch_width
# They are only added after width
newline_features = height_of_patch * num_patch_width
unpadded_features, newline_features = get_unpadded_features(
height, width, npatches, num_patch_height, num_patch_width
)
# The base patch covers the entire image
base_features = npatches**2
return unpadded_features + newline_features + base_features
@ -99,12 +136,9 @@ def load_data_uri(image_uri: str) -> Image.Image:
return image
# assert get_number_of_features(889, 1024) == 2634, f"{get_number_of_features(889, 1024)}"
# assert get_number_of_features(640, 640) == 2928
class VlmCausalLMBatch(FlashMistralBatch):
pixel_values: Optional[List[torch.Tensor]]
pixel_attention_mask: Optional[List[torch.Tensor]]
image_sizes: Optional[List[Tuple[int, int]]]
@classmethod
@ -112,6 +146,7 @@ class VlmCausalLMBatch(FlashMistralBatch):
def concatenate(cls, batches):
batch = super(VlmCausalLMBatch, cls).concatenate(batches)
batch.pixel_values = None
batch.pixel_attention_mask = None
batch.image_sizes = None
return batch
@ -119,6 +154,7 @@ class VlmCausalLMBatch(FlashMistralBatch):
def filter(self, request_ids: List[int]):
batch = super().filter(request_ids)
batch.pixel_values = None
batch.pixel_attention_mask = None
batch.image_sizes = None
return batch
@ -130,6 +166,7 @@ class VlmCausalLMBatch(FlashMistralBatch):
for r in requests:
chunks = split(r.inputs)
full_text = ""
image_id = 0
for chunk in chunks:
if chunk["type"] == "text":
full_text += chunk["content"]
@ -147,9 +184,7 @@ class VlmCausalLMBatch(FlashMistralBatch):
"Cannot process input image not starting with data:"
)
image_input = processor.image_processor(image, return_tensors="pt")
height, width = image_input["image_sizes"][0]
num_features = get_number_of_features(height, width, config)
full_text += "<image>" * num_features
full_text += image_text_replacement(image_input, config, image_id)
image_inputs.append(image_input)
else:
raise RuntimeError(f"Invalid chunk type {chunk['type']}")
@ -161,12 +196,21 @@ class VlmCausalLMBatch(FlashMistralBatch):
batch_inputs, truncation=True, max_length=max_truncation
)["input_ids"]
if image_inputs:
image_inputs = {
image_input = image_inputs[0]
new_image_inputs = {
"pixel_values": torch.cat(
[img["pixel_values"] for img in image_inputs], dim=0
),
"image_sizes": torch.cat([img["image_sizes"] for img in image_inputs]),
}
if "pixel_attention_mask" in image_input:
new_image_inputs["pixel_attention_mask"] = torch.cat(
[img["pixel_attention_mask"] for img in image_inputs], dim=0
)
if "image_sizes" in image_input:
new_image_inputs["image_sizes"] = torch.cat(
[img["image_sizes"] for img in image_inputs], dim=0
)
image_inputs = new_image_inputs
else:
image_inputs = None
return batch_tokenized_inputs, image_inputs
@ -187,9 +231,19 @@ class VlmCausalLMBatch(FlashMistralBatch):
batch = cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)
if image_inputs is not None:
batch.pixel_values = image_inputs["pixel_values"].to(device=device)
batch.image_sizes = image_inputs["image_sizes"].to(device=device)
if "pixel_attention_mask" in image_inputs:
batch.pixel_attention_mask = image_inputs["pixel_attention_mask"].to(
device=device
)
else:
batch.pixel_attention_mask = None
if "image_sizes" in image_inputs:
batch.image_sizes = image_inputs["image_sizes"].to(device=device)
else:
batch.image_sizes = None
else:
batch.pixel_values = None
batch.pixel_attention_mask = None
batch.image_sizes = None
return batch
@ -199,16 +253,6 @@ class VlmCausalLM(BaseFlashMistral):
def batch_type(self) -> Type[VlmCausalLMBatch]:
return VlmCausalLMBatch
def get_layer_config(self, model) -> Tuple[int, int, int]:
return (
len(model.language_model.model.layers),
model.language_model.model.num_key_value_heads,
model.language_model.model.head_size,
)
def max_past(self) -> Optional[int]:
return getattr(self.model.language_model, "max_past", None)
def forward(
self, batch: VlmCausalLMBatch
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
@ -270,17 +314,14 @@ class VlmCausalLM(BaseFlashMistral):
max_s = min(self.max_past(), max_s)
bs = input_ids.shape[0]
padded_bs = bs
if bs == 3:
padded_bs = 4
elif 3 < bs <= 8:
padded_bs = 8
elif bs > 8:
padded_bs = (bs + 7) // 8 * 8
# Try to find an associated cuda graph
cuda_graph = self.cuda_graphs.get(padded_bs, None)
bs = input_ids.shape[0]
sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs])
if sorted_padded_bs:
# Get associated cuda graph
cuda_graph = self.cuda_graphs[sorted_padded_bs[0]]
else:
cuda_graph = None
if cu_seqlen_prefill is not None or cuda_graph is None:
logits, speculative_logits = self.model.forward(
input_ids=input_ids,
@ -294,12 +335,15 @@ class VlmCausalLM(BaseFlashMistral):
prefill_cache_indices=batch.prefill_cache_indices,
lm_head_indices=lm_head_indices,
pixel_values=batch.pixel_values,
pixel_attention_mask=batch.pixel_attention_mask,
image_sizes=batch.image_sizes,
)
if batch.prefill_cache_indices is not None:
batch.prefill_cache_indices = None
if batch.pixel_values is not None:
batch.pixel_values = None
if batch.pixel_attention_mask is not None:
batch.pixel_attention_mask = None
if batch.image_sizes is not None:
batch.image_sizes = None
return logits, speculative_logits

View File

@ -756,6 +756,19 @@ class TensorParallelHead(SuperLayer):
class TensorParallelColumnLinear(SuperLayer):
@classmethod
def load_gate_up(cls, config, prefix: str, weights, bias: bool):
"""Specific method when the QKV was joined after the fact"""
weight = weights.get_weights_col_packed_gate_up(
prefix, quantize=config.quantize
)
if bias:
raise NotImplementedError("packed_gate_up only implemented without bias")
else:
bias = None
linear = get_linear(weight, bias, config.quantize)
return cls(linear)
@classmethod
def load_qkv(cls, config, prefix: str, weights, bias: bool):
"""Specific method when the QKV was joined after the fact"""

View File

@ -143,6 +143,8 @@ class FrequencyPenaltyLogitsProcessor(LogitsProcessor):
score = torch.gather(scores, 1, input_ids)
# if score < 0 then penalty has to be multiplied to reduce the previous token probability
score = -torch.where(score < 0, score * self.penalty, score / self.penalty)
# set score to 0 where input_ids is a padding token
score *= input_ids.ne(0)
return scores.scatter_add_(1, input_ids, score)
@ -168,6 +170,8 @@ class HeterogeneousFrequencyPenaltyLogitsProcessor(LogitsProcessor):
score = -torch.where(
score < 0, score * self.penalty_tensor, score / self.penalty_tensor
)
# set score to 0 where input_ids is a padding token
score *= input_ids.ne(0)
return scores.scatter_add_(1, input_ids, score)

View File

@ -1,7 +1,6 @@
import torch
# vllm imports
from vllm._C import cache_ops, ops
from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM
_PARTITION_SIZE = 512
@ -13,7 +12,18 @@ def reshape_and_cache(
value_cache: torch.Tensor,
slots: torch.Tensor,
):
cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots, "auto", 1.0)
if IS_CUDA_SYSTEM:
from vllm._C import cache_ops
cache_ops.reshape_and_cache(
key, value, key_cache, value_cache, slots, "auto", 1.0
)
elif IS_ROCM_SYSTEM:
from vllm import cache_ops
cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots)
else:
raise ValueError("vllm is not supported on your system")
def attention(
@ -55,21 +65,43 @@ def attention(
# to parallelize.
use_v1 = max_s <= 8192 and (max_num_partitions == 1 or num_seqs * num_heads > 512)
if use_v1:
ops.paged_attention_v1(
out,
query,
key_cache,
value_cache,
kv_head_mapping,
softmax_scale,
block_tables,
input_lengths,
block_size,
max_s,
None,
"auto",
1.0,
)
if IS_CUDA_SYSTEM:
from vllm._C import ops
ops.paged_attention_v1(
out,
query,
key_cache,
value_cache,
kv_head_mapping,
softmax_scale,
block_tables,
input_lengths,
block_size,
max_s,
None,
"auto",
1.0,
)
elif IS_ROCM_SYSTEM:
from vllm import attention_ops
attention_ops.paged_attention_v1(
out,
query,
key_cache,
value_cache,
kv_head_mapping,
softmax_scale,
block_tables,
input_lengths,
block_size,
max_s,
None,
)
else:
raise ValueError("vllm is not supported on your system")
else:
# Run PagedAttention V2.
assert _PARTITION_SIZE % block_size == 0
@ -84,21 +116,46 @@ def attention(
device=out.device,
)
max_logits = torch.empty_like(exp_sums)
ops.paged_attention_v2(
out,
exp_sums,
max_logits,
tmp_output,
query,
key_cache,
value_cache,
kv_head_mapping,
softmax_scale,
block_tables,
input_lengths,
block_size,
max_s,
None,
"auto",
1.0,
)
if IS_CUDA_SYSTEM:
from vllm._C import ops
ops.paged_attention_v2(
out,
exp_sums,
max_logits,
tmp_output,
query,
key_cache,
value_cache,
kv_head_mapping,
softmax_scale,
block_tables,
input_lengths,
block_size,
max_s,
None,
"auto",
1.0,
)
elif IS_ROCM_SYSTEM:
from vllm import attention_ops
attention_ops.paged_attention_v2(
out,
exp_sums,
max_logits,
tmp_output,
query,
key_cache,
value_cache,
kv_head_mapping,
softmax_scale,
block_tables,
input_lengths,
block_size,
max_s,
None,
)
else:
raise ValueError("vllm is not supported on your system")

View File

@ -1,5 +1,5 @@
import re
from typing import List, Optional, Tuple
from typing import List, Optional, Tuple, Set, Union
import math
import torch
@ -143,12 +143,22 @@ class StopSequenceCriteria:
class StoppingCriteria:
def __init__(
self,
eos_token_id: int,
eos_token_ids: Optional[Union[Set[int], int]],
stop_sequence_criterias: List[StopSequenceCriteria],
max_new_tokens: int = 20,
ignore_eos_token: bool = False,
):
self.eos_token_id = eos_token_id
if eos_token_ids is None:
eos_token_ids = set()
elif isinstance(eos_token_ids, int):
eos_token_ids = set([eos_token_ids])
elif isinstance(eos_token_ids, set):
eos_token_ids = eos_token_ids
else:
raise RuntimeError(
f"eos_token_ids is of invalid type {type(eos_token_ids)}, expected int, None or set[int]"
)
self.eos_token_ids = eos_token_ids
self.stop_sequence_criterias = stop_sequence_criterias
self.max_new_tokens = max_new_tokens
self.current_tokens = 0
@ -160,7 +170,10 @@ class StoppingCriteria:
if self.current_tokens >= self.max_new_tokens:
return True, FinishReason.FINISH_REASON_LENGTH
if not self.ignore_eos_token and last_token == self.eos_token_id:
if isinstance(last_token, torch.Tensor):
last_token = last_token.item()
if not self.ignore_eos_token and last_token in self.eos_token_ids:
return True, FinishReason.FINISH_REASON_EOS_TOKEN
if self.stop_sequence_criterias:
@ -184,8 +197,10 @@ class StoppingCriteria:
stop_sequence_criterias = [
StopSequenceCriteria(sequence) for sequence in pb.stop_sequences
]
# TODO Hack because eos_token_id cannot be what we want.
eos_token_id = getattr(tokenizer, "_eos_token_ids", tokenizer.eos_token_id)
return StoppingCriteria(
tokenizer.eos_token_id,
eos_token_id,
stop_sequence_criterias,
pb.max_new_tokens,
pb.ignore_eos_token,
@ -273,7 +288,7 @@ class HeterogeneousNextTokenChooser:
else None
)
if any([x != 1.0 for x in temperature]):
if any(x != 1.0 for x in temperature):
do_sample = [
sample or x != 1.0 for x, sample in zip(temperature, do_sample)
]
@ -281,15 +296,15 @@ class HeterogeneousNextTokenChooser:
HeterogeneousTemperatureLogitsWarper(temperature, dtype, device)
)
if any([x != 0 for x in top_k]):
if any(x != 0 for x in top_k):
do_sample = [sample or x != 0 for x, sample in zip(top_k, do_sample)]
warpers.append(HeterogeneousTopKLogitsWarper(top_k, device))
if any([x < 1.0 for x in top_p]):
if any(x < 1.0 for x in top_p):
do_sample = [sample or x < 1.0 for x, sample in zip(top_p, do_sample)]
warpers.append(HeterogeneousTopPLogitsWarper(top_p, dtype, device))
if any([x < 1.0 for x in typical_p]):
if any(x < 1.0 for x in typical_p):
do_sample = [sample or x < 1.0 for x, sample in zip(typical_p, do_sample)]
warpers.append(HeterogeneousTypicalLogitsWarper(typical_p, dtype, device))

View File

@ -141,6 +141,12 @@ class Weights:
return weight
def get_weights_col_packed_qkv(self, prefix: str, quantize: str):
return self.get_weights_col_packed(prefix, quantize, 3)
def get_weights_col_packed_gate_up(self, prefix: str, quantize: str):
return self.get_weights_col_packed(prefix, quantize, 2)
def get_weights_col_packed(self, prefix: str, quantize: str, blocks: int):
"""
Highly specific when the underlying tensor is a simple cat of Q,K,V instead of being
already alternating Q,K,V within the main tensor
@ -181,8 +187,8 @@ class Weights:
else:
slice_ = self._get_slice(f"{prefix}.weight")
total_size = slice_.get_shape()[0]
assert total_size % 3 == 0, "Prepacked qkv is not divisible by 3"
single_size = total_size // 3
assert total_size % blocks == 0, f"Prepacked is not divisible by {blocks}"
single_size = total_size // blocks
world_size = self.process_group.size()
rank = self.process_group.rank()
@ -192,10 +198,11 @@ class Weights:
block_size = single_size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
q = slice_[start:stop]
k = slice_[start + single_size : stop + single_size]
v = slice_[start + 2 * single_size : stop + 2 * single_size]
weight = torch.cat([q, k, v], dim=0)
tensors = []
for i in range(blocks):
tensor = slice_[start + i * single_size : stop + i * single_size]
tensors.append(tensor)
weight = torch.cat(tensors, dim=0)
weight = weight.to(device=self.device)
weight = weight.to(dtype=self.dtype)
return weight