Merge branch 'main' into mi300-compat

This commit is contained in:
fxmarty 2024-04-26 11:28:42 +02:00
commit 7502367043
54 changed files with 29513 additions and 21069 deletions

248
Cargo.lock generated
View File

@ -120,7 +120,7 @@ checksum = "0ae92a5119aa49cdbcf6b9f893fe4e1d98b04ccbf82ee0584ad948a44a734dea"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.58", "syn 2.0.60",
] ]
[[package]] [[package]]
@ -159,7 +159,7 @@ checksum = "16e62a023e7c117e27523144c5d2459f4397fcc3cab0085af8e2224f643a0193"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.58", "syn 2.0.60",
] ]
[[package]] [[package]]
@ -170,7 +170,7 @@ checksum = "c6fa2087f2753a7da8cc1c0dbfcf89579dd57458e36769de5ac750b4671737ca"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.58", "syn 2.0.60",
] ]
[[package]] [[package]]
@ -449,9 +449,9 @@ checksum = "df8670b8c7b9dae1793364eafadf7239c40d669904660c5960d74cfd80b46a53"
[[package]] [[package]]
name = "cc" name = "cc"
version = "1.0.92" version = "1.0.94"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2678b2e3449475e95b0aa6f9b506a28e61b3dc8996592b983695e8ebb58a8b41" checksum = "17f6e324229dc011159fcc089755d1e2e216a90d43a7dea6853ca740b84f35e7"
dependencies = [ dependencies = [
"jobserver", "jobserver",
"libc", "libc",
@ -510,7 +510,7 @@ dependencies = [
"heck 0.5.0", "heck 0.5.0",
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.58", "syn 2.0.60",
] ]
[[package]] [[package]]
@ -665,9 +665,9 @@ dependencies = [
[[package]] [[package]]
name = "darling" name = "darling"
version = "0.14.4" version = "0.20.8"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7b750cb3417fd1b327431a470f388520309479ab0bf5e323505daf0290cd3850" checksum = "54e36fcd13ed84ffdfda6f5be89b31287cbb80c439841fe69e04841435464391"
dependencies = [ dependencies = [
"darling_core", "darling_core",
"darling_macro", "darling_macro",
@ -675,27 +675,27 @@ dependencies = [
[[package]] [[package]]
name = "darling_core" name = "darling_core"
version = "0.14.4" version = "0.20.8"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "109c1ca6e6b7f82cc233a97004ea8ed7ca123a9af07a8230878fcfda9b158bf0" checksum = "9c2cf1c23a687a1feeb728783b993c4e1ad83d99f351801977dd809b48d0a70f"
dependencies = [ dependencies = [
"fnv", "fnv",
"ident_case", "ident_case",
"proc-macro2", "proc-macro2",
"quote", "quote",
"strsim 0.10.0", "strsim 0.10.0",
"syn 1.0.109", "syn 2.0.60",
] ]
[[package]] [[package]]
name = "darling_macro" name = "darling_macro"
version = "0.14.4" version = "0.20.8"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a4aab4dbc9f7611d8b55048a3a16d2d010c2c8334e46304b40ac1cc14bf3b48e" checksum = "a668eda54683121533a393014d8692171709ff57a7d61f187b6e782719f8933f"
dependencies = [ dependencies = [
"darling_core", "darling_core",
"quote", "quote",
"syn 1.0.109", "syn 2.0.60",
] ]
[[package]] [[package]]
@ -709,33 +709,33 @@ dependencies = [
[[package]] [[package]]
name = "derive_builder" name = "derive_builder"
version = "0.12.0" version = "0.20.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8d67778784b508018359cbc8696edb3db78160bab2c2a28ba7f56ef6932997f8" checksum = "0350b5cb0331628a5916d6c5c0b72e97393b8b6b03b47a9284f4e7f5a405ffd7"
dependencies = [ dependencies = [
"derive_builder_macro", "derive_builder_macro",
] ]
[[package]] [[package]]
name = "derive_builder_core" name = "derive_builder_core"
version = "0.12.0" version = "0.20.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c11bdc11a0c47bc7d37d582b5285da6849c96681023680b906673c5707af7b0f" checksum = "d48cda787f839151732d396ac69e3473923d54312c070ee21e9effcaa8ca0b1d"
dependencies = [ dependencies = [
"darling", "darling",
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 1.0.109", "syn 2.0.60",
] ]
[[package]] [[package]]
name = "derive_builder_macro" name = "derive_builder_macro"
version = "0.12.0" version = "0.20.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ebcda35c7a396850a55ffeac740804b40ffec779b98fffbb1738f4033f0ee79e" checksum = "206868b8242f27cecce124c19fd88157fbd0dd334df2587f36417bafbc85097b"
dependencies = [ dependencies = [
"derive_builder_core", "derive_builder_core",
"syn 1.0.109", "syn 2.0.60",
] ]
[[package]] [[package]]
@ -800,9 +800,9 @@ dependencies = [
[[package]] [[package]]
name = "either" name = "either"
version = "1.10.0" version = "1.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "11157ac094ffbdde99aa67b23417ebdd801842852b500e395a45a9c0aac03e4a" checksum = "a47c1c47d2f5964e29c61246e81db715514cd532db6b5116a25ea3c03d6780a2"
[[package]] [[package]]
name = "encode_unicode" name = "encode_unicode"
@ -1018,7 +1018,7 @@ checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.58", "syn 2.0.60",
] ]
[[package]] [[package]]
@ -1423,7 +1423,7 @@ checksum = "c34819042dc3d3971c46c2190835914dfbe0c3c13f61449b2997f4e9722dfa60"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.58", "syn 2.0.60",
] ]
[[package]] [[package]]
@ -1476,9 +1476,9 @@ checksum = "49f1f14873335454500d59611f1cf4a4b0f786f9ac11f4312a78e4cf2566695b"
[[package]] [[package]]
name = "jobserver" name = "jobserver"
version = "0.1.29" version = "0.1.30"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f08474e32172238f2827bd160c67871cdb2801430f65c3979184dc362e3ca118" checksum = "685a7d121ee3f65ae4fddd72b25a04bb36b6af81bc0828f7d5434c0fe60fa3a2"
dependencies = [ dependencies = [
"libc", "libc",
] ]
@ -1703,7 +1703,7 @@ checksum = "38b4faf00617defe497754acde3024865bc143d44a86799b24e191ecff91354f"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.58", "syn 2.0.60",
] ]
[[package]] [[package]]
@ -1775,9 +1775,9 @@ dependencies = [
[[package]] [[package]]
name = "monostate" name = "monostate"
version = "0.1.11" version = "0.1.12"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "878c2a1f1c70e5724fa28f101ca787b6a7e8ad5c5e4ae4ca3b0fa4a419fa9075" checksum = "a20fffcd8ca4c69d31e036a71abc400147b41f90895df4edcb36497a1f8af8bf"
dependencies = [ dependencies = [
"monostate-impl", "monostate-impl",
"serde", "serde",
@ -1785,13 +1785,13 @@ dependencies = [
[[package]] [[package]]
name = "monostate-impl" name = "monostate-impl"
version = "0.1.11" version = "0.1.12"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f686d68a09079e63b1d2c64aa305095887ce50565f00a922ebfaeeee0d9ba6ce" checksum = "bf307cbbbd777a9c10cec88ddafee572b3484caad5cce0c9236523c3803105a6"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.58", "syn 2.0.60",
] ]
[[package]] [[package]]
@ -1929,9 +1929,9 @@ dependencies = [
[[package]] [[package]]
name = "num" name = "num"
version = "0.4.1" version = "0.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b05180d69e3da0e530ba2a1dae5110317e49e3b7f3d41be227dc5f92e49ee7af" checksum = "3135b08af27d103b0a51f2ae0f8632117b7b185ccf931445affa8df530576a41"
dependencies = [ dependencies = [
"num-bigint", "num-bigint",
"num-complex", "num-complex",
@ -1981,7 +1981,7 @@ checksum = "ed3955f1a9c7c0c15e092f9c887db08b1fc683305fdf6eb6684f22555355e202"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.58", "syn 2.0.60",
] ]
[[package]] [[package]]
@ -2111,7 +2111,7 @@ checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.58", "syn 2.0.60",
] ]
[[package]] [[package]]
@ -2327,7 +2327,7 @@ checksum = "2f38a4412a78282e09a2cf38d195ea5420d15ba0602cb375210efbc877243965"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.58", "syn 2.0.60",
] ]
[[package]] [[package]]
@ -2381,12 +2381,12 @@ checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de"
[[package]] [[package]]
name = "prettyplease" name = "prettyplease"
version = "0.2.17" version = "0.2.19"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8d3928fb5db768cb86f891ff014f0144589297e3c6a1aba6ed7cecfdace270c7" checksum = "5ac2cf0f2e4f42b49f5ffd07dae8d746508ef7526c13940e5f524012ae6c6550"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"syn 2.0.58", "syn 2.0.60",
] ]
[[package]] [[package]]
@ -2415,9 +2415,9 @@ dependencies = [
[[package]] [[package]]
name = "proc-macro2" name = "proc-macro2"
version = "1.0.79" version = "1.0.81"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e835ff2298f5721608eb1a980ecaee1aef2c132bf95ecc026a11b7bf3c01c02e" checksum = "3d1597b0c024618f09a9c3b8655b7e430397a36d23fdafec26d6965e9eec3eba"
dependencies = [ dependencies = [
"unicode-ident", "unicode-ident",
] ]
@ -2438,7 +2438,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8021cf59c8ec9c432cfc2526ac6b8aa508ecaf29cd415f271b8406c1b851c3fd" checksum = "8021cf59c8ec9c432cfc2526ac6b8aa508ecaf29cd415f271b8406c1b851c3fd"
dependencies = [ dependencies = [
"quote", "quote",
"syn 2.0.58", "syn 2.0.60",
] ]
[[package]] [[package]]
@ -2478,7 +2478,7 @@ dependencies = [
"prost 0.12.4", "prost 0.12.4",
"prost-types", "prost-types",
"regex", "regex",
"syn 2.0.58", "syn 2.0.60",
"tempfile", "tempfile",
] ]
@ -2505,7 +2505,7 @@ dependencies = [
"itertools 0.12.1", "itertools 0.12.1",
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.58", "syn 2.0.60",
] ]
[[package]] [[package]]
@ -2752,12 +2752,6 @@ version = "0.6.29"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f162c6dd7b008981e4d40210aca20b4bd0f9b60ca9271061b07f78537722f2e1" checksum = "f162c6dd7b008981e4d40210aca20b4bd0f9b60ca9271061b07f78537722f2e1"
[[package]]
name = "regex-syntax"
version = "0.7.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dbb5fb1acd8a1a18b3dd5be62d25485eb770e05afb408a9627d14d451bae12da"
[[package]] [[package]]
name = "regex-syntax" name = "regex-syntax"
version = "0.8.3" version = "0.8.3"
@ -2864,7 +2858,7 @@ dependencies = [
"quote", "quote",
"rust-embed-utils", "rust-embed-utils",
"shellexpand", "shellexpand",
"syn 2.0.58", "syn 2.0.60",
"walkdir", "walkdir",
] ]
@ -3038,29 +3032,29 @@ dependencies = [
[[package]] [[package]]
name = "serde" name = "serde"
version = "1.0.197" version = "1.0.198"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3fb1c873e1b9b056a4dc4c0c198b24c3ffa059243875552b2bd0933b1aee4ce2" checksum = "9846a40c979031340571da2545a4e5b7c4163bdae79b301d5f86d03979451fcc"
dependencies = [ dependencies = [
"serde_derive", "serde_derive",
] ]
[[package]] [[package]]
name = "serde_derive" name = "serde_derive"
version = "1.0.197" version = "1.0.198"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7eb0b34b42edc17f6b7cac84a52a1c5f0e1bb2227e997ca9011ea3dd34e8610b" checksum = "e88edab869b01783ba905e7d0153f9fc1a6505a96e4ad3018011eedb838566d9"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.58", "syn 2.0.60",
] ]
[[package]] [[package]]
name = "serde_json" name = "serde_json"
version = "1.0.115" version = "1.0.116"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "12dc5c46daa8e9fdf4f5e71b6cf9a53f2487da0e86e55808e2d35539666497dd" checksum = "3e17db7126d17feb94eb3fad46bf1a96b034e8aacbc2e775fe81505f8b0b2813"
dependencies = [ dependencies = [
"itoa", "itoa",
"ryu", "ryu",
@ -3270,7 +3264,7 @@ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"rustversion", "rustversion",
"syn 2.0.58", "syn 2.0.60",
] ]
[[package]] [[package]]
@ -3292,9 +3286,9 @@ dependencies = [
[[package]] [[package]]
name = "syn" name = "syn"
version = "2.0.58" version = "2.0.60"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "44cfb93f38070beee36b3fef7d4f5a16f27751d94b187b666a5cc5e9b0d30687" checksum = "909518bc7b1c9b779f1bbf07f2929d35af9f0f37e47c6e9ef7f9dddc1e1821f3"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
@ -3399,7 +3393,7 @@ dependencies = [
[[package]] [[package]]
name = "text-generation-benchmark" name = "text-generation-benchmark"
version = "2.0.0" version = "2.0.1"
dependencies = [ dependencies = [
"average", "average",
"clap", "clap",
@ -3412,7 +3406,7 @@ dependencies = [
"tabled", "tabled",
"text-generation-client", "text-generation-client",
"thiserror", "thiserror",
"tokenizers 0.14.1", "tokenizers",
"tokio", "tokio",
"tracing", "tracing",
"tracing-subscriber", "tracing-subscriber",
@ -3420,7 +3414,7 @@ dependencies = [
[[package]] [[package]]
name = "text-generation-client" name = "text-generation-client"
version = "2.0.0" version = "2.0.1"
dependencies = [ dependencies = [
"futures", "futures",
"grpc-metadata", "grpc-metadata",
@ -3436,7 +3430,7 @@ dependencies = [
[[package]] [[package]]
name = "text-generation-launcher" name = "text-generation-launcher"
version = "2.0.0" version = "2.0.1"
dependencies = [ dependencies = [
"clap", "clap",
"ctrlc", "ctrlc",
@ -3454,7 +3448,7 @@ dependencies = [
[[package]] [[package]]
name = "text-generation-router" name = "text-generation-router"
version = "2.0.0" version = "2.0.1"
dependencies = [ dependencies = [
"async-stream", "async-stream",
"axum", "axum",
@ -3482,7 +3476,7 @@ dependencies = [
"serde_json", "serde_json",
"text-generation-client", "text-generation-client",
"thiserror", "thiserror",
"tokenizers 0.15.2", "tokenizers",
"tokio", "tokio",
"tokio-stream", "tokio-stream",
"tower-http", "tower-http",
@ -3511,7 +3505,7 @@ checksum = "c61f3ba182994efc43764a46c018c347bc492c79f024e705f46567b418f6d4f7"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.58", "syn 2.0.60",
] ]
[[package]] [[package]]
@ -3585,46 +3579,11 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20"
[[package]] [[package]]
name = "tokenizers" name = "tokenizers"
version = "0.14.1" version = "0.19.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d9be88c795d8b9f9c4002b3a8f26a6d0876103a6f523b32ea3bac52d8560c17c" checksum = "e500fad1dd3af3d626327e6a3fe5050e664a6eaa4708b8ca92f1794aaf73e6fd"
dependencies = [ dependencies = [
"aho-corasick", "aho-corasick",
"clap",
"derive_builder",
"esaxx-rs",
"getrandom",
"hf-hub",
"indicatif",
"itertools 0.11.0",
"lazy_static",
"log",
"macro_rules_attribute",
"monostate",
"onig",
"paste",
"rand",
"rayon",
"rayon-cond",
"regex",
"regex-syntax 0.7.5",
"serde",
"serde_json",
"spm_precompiled",
"thiserror",
"unicode-normalization-alignments",
"unicode-segmentation",
"unicode_categories",
]
[[package]]
name = "tokenizers"
version = "0.15.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3dd47962b0ba36e7fd33518fbf1754d136fd1474000162bbf2a8b5fcb2d3654d"
dependencies = [
"aho-corasick",
"clap",
"derive_builder", "derive_builder",
"esaxx-rs", "esaxx-rs",
"getrandom", "getrandom",
@ -3688,7 +3647,7 @@ checksum = "5b8a1e28f2deaa14e508979454cb3a223b10b938b45af148bc0986de36f1923b"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.58", "syn 2.0.60",
] ]
[[package]] [[package]]
@ -3837,7 +3796,7 @@ dependencies = [
"proc-macro2", "proc-macro2",
"prost-build", "prost-build",
"quote", "quote",
"syn 2.0.58", "syn 2.0.60",
] ]
[[package]] [[package]]
@ -3910,7 +3869,7 @@ checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.58", "syn 2.0.60",
] ]
[[package]] [[package]]
@ -4151,7 +4110,7 @@ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"regex", "regex",
"syn 2.0.58", "syn 2.0.60",
] ]
[[package]] [[package]]
@ -4273,7 +4232,7 @@ dependencies = [
"once_cell", "once_cell",
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.58", "syn 2.0.60",
"wasm-bindgen-shared", "wasm-bindgen-shared",
] ]
@ -4307,7 +4266,7 @@ checksum = "e94f17b526d0a461a191c78ea52bbce64071ed5c04c9ffe424dcb38f74171bb7"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.58", "syn 2.0.60",
"wasm-bindgen-backend", "wasm-bindgen-backend",
"wasm-bindgen-shared", "wasm-bindgen-shared",
] ]
@ -4391,7 +4350,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e48a53791691ab099e5e2ad123536d0fff50652600abaf43bbf952894110d0be" checksum = "e48a53791691ab099e5e2ad123536d0fff50652600abaf43bbf952894110d0be"
dependencies = [ dependencies = [
"windows-core", "windows-core",
"windows-targets 0.52.4", "windows-targets 0.52.5",
] ]
[[package]] [[package]]
@ -4400,7 +4359,7 @@ version = "0.52.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "33ab640c8d7e35bf8ba19b884ba838ceb4fba93a4e8c65a9059d08afcfc683d9" checksum = "33ab640c8d7e35bf8ba19b884ba838ceb4fba93a4e8c65a9059d08afcfc683d9"
dependencies = [ dependencies = [
"windows-targets 0.52.4", "windows-targets 0.52.5",
] ]
[[package]] [[package]]
@ -4427,7 +4386,7 @@ version = "0.52.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d"
dependencies = [ dependencies = [
"windows-targets 0.52.4", "windows-targets 0.52.5",
] ]
[[package]] [[package]]
@ -4462,17 +4421,18 @@ dependencies = [
[[package]] [[package]]
name = "windows-targets" name = "windows-targets"
version = "0.52.4" version = "0.52.5"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7dd37b7e5ab9018759f893a1952c9420d060016fc19a472b4bb20d1bdd694d1b" checksum = "6f0713a46559409d202e70e28227288446bf7841d3211583a4b53e3f6d96e7eb"
dependencies = [ dependencies = [
"windows_aarch64_gnullvm 0.52.4", "windows_aarch64_gnullvm 0.52.5",
"windows_aarch64_msvc 0.52.4", "windows_aarch64_msvc 0.52.5",
"windows_i686_gnu 0.52.4", "windows_i686_gnu 0.52.5",
"windows_i686_msvc 0.52.4", "windows_i686_gnullvm",
"windows_x86_64_gnu 0.52.4", "windows_i686_msvc 0.52.5",
"windows_x86_64_gnullvm 0.52.4", "windows_x86_64_gnu 0.52.5",
"windows_x86_64_msvc 0.52.4", "windows_x86_64_gnullvm 0.52.5",
"windows_x86_64_msvc 0.52.5",
] ]
[[package]] [[package]]
@ -4489,9 +4449,9 @@ checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8"
[[package]] [[package]]
name = "windows_aarch64_gnullvm" name = "windows_aarch64_gnullvm"
version = "0.52.4" version = "0.52.5"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bcf46cf4c365c6f2d1cc93ce535f2c8b244591df96ceee75d8e83deb70a9cac9" checksum = "7088eed71e8b8dda258ecc8bac5fb1153c5cffaf2578fc8ff5d61e23578d3263"
[[package]] [[package]]
name = "windows_aarch64_msvc" name = "windows_aarch64_msvc"
@ -4507,9 +4467,9 @@ checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc"
[[package]] [[package]]
name = "windows_aarch64_msvc" name = "windows_aarch64_msvc"
version = "0.52.4" version = "0.52.5"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "da9f259dd3bcf6990b55bffd094c4f7235817ba4ceebde8e6d11cd0c5633b675" checksum = "9985fd1504e250c615ca5f281c3f7a6da76213ebd5ccc9561496568a2752afb6"
[[package]] [[package]]
name = "windows_i686_gnu" name = "windows_i686_gnu"
@ -4525,9 +4485,15 @@ checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e"
[[package]] [[package]]
name = "windows_i686_gnu" name = "windows_i686_gnu"
version = "0.52.4" version = "0.52.5"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b474d8268f99e0995f25b9f095bc7434632601028cf86590aea5c8a5cb7801d3" checksum = "88ba073cf16d5372720ec942a8ccbf61626074c6d4dd2e745299726ce8b89670"
[[package]]
name = "windows_i686_gnullvm"
version = "0.52.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "87f4261229030a858f36b459e748ae97545d6f1ec60e5e0d6a3d32e0dc232ee9"
[[package]] [[package]]
name = "windows_i686_msvc" name = "windows_i686_msvc"
@ -4543,9 +4509,9 @@ checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406"
[[package]] [[package]]
name = "windows_i686_msvc" name = "windows_i686_msvc"
version = "0.52.4" version = "0.52.5"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1515e9a29e5bed743cb4415a9ecf5dfca648ce85ee42e15873c3cd8610ff8e02" checksum = "db3c2bf3d13d5b658be73463284eaf12830ac9a26a90c717b7f771dfe97487bf"
[[package]] [[package]]
name = "windows_x86_64_gnu" name = "windows_x86_64_gnu"
@ -4561,9 +4527,9 @@ checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e"
[[package]] [[package]]
name = "windows_x86_64_gnu" name = "windows_x86_64_gnu"
version = "0.52.4" version = "0.52.5"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5eee091590e89cc02ad514ffe3ead9eb6b660aedca2183455434b93546371a03" checksum = "4e4246f76bdeff09eb48875a0fd3e2af6aada79d409d33011886d3e1581517d9"
[[package]] [[package]]
name = "windows_x86_64_gnullvm" name = "windows_x86_64_gnullvm"
@ -4579,9 +4545,9 @@ checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc"
[[package]] [[package]]
name = "windows_x86_64_gnullvm" name = "windows_x86_64_gnullvm"
version = "0.52.4" version = "0.52.5"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "77ca79f2451b49fa9e2af39f0747fe999fcda4f5e241b2898624dca97a1f2177" checksum = "852298e482cd67c356ddd9570386e2862b5673c85bd5f88df9ab6802b334c596"
[[package]] [[package]]
name = "windows_x86_64_msvc" name = "windows_x86_64_msvc"
@ -4597,9 +4563,9 @@ checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538"
[[package]] [[package]]
name = "windows_x86_64_msvc" name = "windows_x86_64_msvc"
version = "0.52.4" version = "0.52.5"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "32b752e52a2da0ddfbdbcc6fceadfeede4c939ed16d13e648833a61dfb611ed8" checksum = "bec47e5bfd1bff0eeaf6d8b485cc1074891a197ab4225d504cb7a1ab88b02bf0"
[[package]] [[package]]
name = "winnow" name = "winnow"
@ -4637,7 +4603,7 @@ checksum = "9ce1b18ccd8e73a9321186f97e46f9f04b778851177567b1975109d26a08d2a6"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.58", "syn 2.0.60",
] ]
[[package]] [[package]]

View File

@ -9,11 +9,15 @@ members = [
resolver = "2" resolver = "2"
[workspace.package] [workspace.package]
version = "2.0.0" version = "2.0.1"
edition = "2021" edition = "2021"
authors = ["Olivier Dehaene"] authors = ["Olivier Dehaene"]
homepage = "https://github.com/huggingface/text-generation-inference" homepage = "https://github.com/huggingface/text-generation-inference"
[workspace.dependencies]
tokenizers = { version = "0.19.1", features = ["http"] }
hf-hub = { version = "0.3.1", features = ["tokio"] }
[profile.release] [profile.release]
debug = 1 debug = 1
incremental = true incremental = true

View File

@ -23,9 +23,9 @@ serde_json = "1.0"
tabled = "0.14.0" tabled = "0.14.0"
text-generation-client = { path = "../router/client" } text-generation-client = { path = "../router/client" }
thiserror = "1.0.48" thiserror = "1.0.48"
tokenizers = { version = "0.14.0", features = ["http"] } tokenizers = { workspace = true }
tokio = { version = "1.32.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync", "macros"] } tokio = { version = "1.32.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync", "macros"] }
tui = {package = "ratatui", version = "0.23", default-features = false, features = ["crossterm"]} tui = {package = "ratatui", version = "0.23", default-features = false, features = ["crossterm"]}
tracing = "0.1.37" tracing = "0.1.37"
tracing-subscriber = { version = "0.3.17", features = ["json", "env-filter"] } tracing-subscriber = { version = "0.3.17", features = ["json", "env-filter"] }
hf-hub = "0.3.1" hf-hub = { workspace = true }

View File

@ -10,7 +10,7 @@
"name": "Apache 2.0", "name": "Apache 2.0",
"url": "https://www.apache.org/licenses/LICENSE-2.0" "url": "https://www.apache.org/licenses/LICENSE-2.0"
}, },
"version": "2.0.0" "version": "2.0.1"
}, },
"paths": { "paths": {
"/": { "/": {
@ -408,9 +408,14 @@
}, },
"responses": { "responses": {
"200": { "200": {
"description": "Generated Text", "description": "Generated Chat Completion",
"content": { "content": {
"application/json": { "application/json": {
"schema": {
"$ref": "#/components/schemas/ChatCompletion"
}
},
"text/event-stream": {
"schema": { "schema": {
"$ref": "#/components/schemas/ChatCompletionChunk" "$ref": "#/components/schemas/ChatCompletionChunk"
} }
@ -492,11 +497,16 @@
}, },
"responses": { "responses": {
"200": { "200": {
"description": "Generated Text", "description": "Generated Chat Completion",
"content": { "content": {
"application/json": { "application/json": {
"schema": { "schema": {
"$ref": "#/components/schemas/ChatCompletionChunk" "$ref": "#/components/schemas/Completion"
}
},
"text/event-stream": {
"schema": {
"$ref": "#/components/schemas/CompletionCompleteChunk"
} }
} }
} }
@ -930,7 +940,7 @@
"tool_prompt": { "tool_prompt": {
"type": "string", "type": "string",
"description": "A prompt to be appended before the tools", "description": "A prompt to be appended before the tools",
"example": "\"Based on the conversation, please choose the most appropriate tool to use: \"", "example": "\"You will be presented with a JSON schema representing a set of tools.\nIf the user request lacks of sufficient information to make a precise tool selection: Do not invent any tool's properties, instead notify with an error message.\n\nJSON Schema:\n\"",
"nullable": true "nullable": true
}, },
"tools": { "tools": {
@ -1071,7 +1081,10 @@
"example": "mistralai/Mistral-7B-Instruct-v0.2" "example": "mistralai/Mistral-7B-Instruct-v0.2"
}, },
"prompt": { "prompt": {
"type": "string", "type": "array",
"items": {
"type": "string"
},
"description": "The prompt to generate completions for.", "description": "The prompt to generate completions for.",
"example": "What is Deep Learning?" "example": "What is Deep Learning?"
}, },
@ -1234,17 +1247,17 @@
"type": "object", "type": "object",
"required": [ "required": [
"name", "name",
"parameters" "arguments"
], ],
"properties": { "properties": {
"arguments": {},
"description": { "description": {
"type": "string", "type": "string",
"nullable": true "nullable": true
}, },
"name": { "name": {
"type": "string" "type": "string"
}, }
"parameters": {}
} }
}, },
"GenerateParameters": { "GenerateParameters": {
@ -1260,7 +1273,7 @@
}, },
"decoder_input_details": { "decoder_input_details": {
"type": "boolean", "type": "boolean",
"default": "true" "default": "false"
}, },
"details": { "details": {
"type": "boolean", "type": "boolean",
@ -1285,6 +1298,7 @@
"$ref": "#/components/schemas/GrammarType" "$ref": "#/components/schemas/GrammarType"
} }
], ],
"default": "null",
"nullable": true "nullable": true
}, },
"max_new_tokens": { "max_new_tokens": {
@ -1478,6 +1492,7 @@
"max_batch_total_tokens", "max_batch_total_tokens",
"max_waiting_tokens", "max_waiting_tokens",
"validation_workers", "validation_workers",
"max_client_batch_size",
"version" "version"
], ],
"properties": { "properties": {
@ -1503,6 +1518,11 @@
"example": "2", "example": "2",
"minimum": 0 "minimum": 0
}, },
"max_client_batch_size": {
"type": "integer",
"example": "32",
"minimum": 0
},
"max_concurrent_requests": { "max_concurrent_requests": {
"type": "integer", "type": "integer",
"description": "Router Parameters", "description": "Router Parameters",

View File

@ -1,8 +1,10 @@
# Guidance # Guidance
Text Generation Inference (TGI) now supports [JSON and regex grammars](#grammar-and-constraints) and [tools and functions](#tools-and-functions) to help developer guide LLM responses to fit their needs. Text Generation Inference (TGI) now supports [JSON and regex grammars](#grammar-and-constraints) and [tools and functions](#tools-and-functions) to help developers guide LLM responses to fit their needs.
These feature are available starting from version `1.4.3`. They are accessible via the [text_generation](https://pypi.org/project/text-generation/) library and is compatible with OpenAI's client libraries. The following guide will walk you through the new features and how to use them! These feature are available starting from version `1.4.3`. They are accessible via the [text_generation](https://pypi.org/project/text-generation/) library. The tool support is compatible with OpenAI's client libraries. The following guide will walk you through the new features and how to use them!
> The Grammar guidance support is currently only available in the TGI API due to lack of support in Open AI API.
## Quick Start ## Quick Start
@ -16,7 +18,7 @@ If you're not up to date, grab the latest version and let's get started!
- [The Grammar Parameter](#the-grammar-parameter): Shape your AI's responses with precision. - [The Grammar Parameter](#the-grammar-parameter): Shape your AI's responses with precision.
- [Constrain with Pydantic](#constrain-with-pydantic): Define a grammar using Pydantic models. - [Constrain with Pydantic](#constrain-with-pydantic): Define a grammar using Pydantic models.
- [JSON Schema Integration](#json-schema-integration): Fine grain control over your requests via JSON schema. - [JSON Schema Integration](#json-schema-integration): Fine-grained control over your requests via JSON schema.
- [Using the client](#using-the-client): Use TGI's client libraries to shape the AI's responses. - [Using the client](#using-the-client): Use TGI's client libraries to shape the AI's responses.
### Tools and Functions ### Tools and Functions
@ -72,9 +74,9 @@ curl localhost:3000/generate \
``` ```
A grammar can be defined using Pydantic models, JSON schema, or regular expressions. The AI will then generate a response that conforms to the specified grammar. A grammar can be defined using Pydantic models, JSON schemas, or regular expressions. The AI will then generate a response that conforms to the specified grammar.
> Note: A grammar must compile to a intermediate representation to constrain the output. Grammar compilation is a computationally expensive and may take a few seconds to complete on the first request. Subsequent requests will use the cached grammar and will be much faster. > Note: A grammar must compile to an intermediate representation to constrain the output. Grammar compilation is a computationally expensive and may take a few seconds to complete on the first request. Subsequent requests will use the cached grammar and will be much faster.
### Constrain with Pydantic ### Constrain with Pydantic
@ -151,7 +153,7 @@ json_schema = {
} }
data = { data = {
"inputs": "[INST]convert to JSON: I saw a puppy a cat and a raccoon during my bike ride in the park [/INST]", "inputs": "convert to JSON: I saw a puppy a cat and a raccoon during my bike ride in the park",
"parameters": { "parameters": {
"max_new_tokens": 200, "max_new_tokens": 200,
"repetition_penalty": 1.3, "repetition_penalty": 1.3,

View File

@ -1,5 +1,6 @@
## Speculation ## Speculation
Speculative decoding, assisted generation, Medusa, and others are a few different names for the same idea. Speculative decoding, assisted generation, Medusa, and others are a few different names for the same idea.
The idea is to generate tokens *before* the large model actually runs, and only *check* if those tokens where valid. The idea is to generate tokens *before* the large model actually runs, and only *check* if those tokens where valid.
@ -36,7 +37,7 @@ In order to use medusa models in TGI, simply point to a medusa enabled model, an
If you don't have a medusa model, or don't have the resource to fine-tune, you can try to use `n-gram`. If you don't have a medusa model, or don't have the resource to fine-tune, you can try to use `n-gram`.
Ngram works by trying to find in the previous sequence existing tokens that match, and use those as speculation. N-gram works by trying to find matching tokens in the previous sequence, and use those as speculation for generating new tokens. For example, if the tokens "np.mean" appear multiple times in the sequence, the model can speculate that the next continuation of the tokens "np." is probably also "mean".
This is an extremely simple method, which works best for code, or highly repetitive text. This might not be beneficial, if the speculation misses too much. This is an extremely simple method, which works best for code, or highly repetitive text. This might not be beneficial, if the speculation misses too much.

View File

@ -15,7 +15,7 @@ Token streaming is the mode in which the server returns the tokens one by one as
/> />
</div> </div>
With token streaming, the server can start returning the tokens one by one before having to generate the whole response. Users can have a sense of the generation's quality earlier than the end of the generation. This has different positive effects: With token streaming, the server can start returning the tokens one by one before having to generate the whole response. Users can have a sense of the generation's quality before the end of the generation. This has different positive effects:
* Users can get results orders of magnitude earlier for extremely long queries. * Users can get results orders of magnitude earlier for extremely long queries.
* Seeing something in progress allows users to stop the generation if it's not going in the direction they expect. * Seeing something in progress allows users to stop the generation if it's not going in the direction they expect.
@ -116,7 +116,7 @@ curl -N 127.0.0.1:8080/generate_stream \
First, we need to install the `@huggingface/inference` library. First, we need to install the `@huggingface/inference` library.
`npm install @huggingface/inference` `npm install @huggingface/inference`
If you're using the free Inference API, you can use `HfInference`. If you're using inference endpoints, you can use `HfInferenceEndpoint`. Let's If you're using the free Inference API, you can use `HfInference`. If you're using inference endpoints, you can use `HfInferenceEndpoint`.
We can create a `HfInferenceEndpoint` providing our endpoint URL and credential. We can create a `HfInferenceEndpoint` providing our endpoint URL and credential.

View File

@ -18,8 +18,8 @@ Text Generation Inference implements many optimizations and features, such as:
- Logits warper (temperature scaling, top-p, top-k, repetition penalty) - Logits warper (temperature scaling, top-p, top-k, repetition penalty)
- Stop sequences - Stop sequences
- Log probabilities - Log probabilities
- Custom Prompt Generation: Easily generate text by providing custom prompts to guide the model's output.
- Fine-tuning Support: Utilize fine-tuned models for specific tasks to achieve higher accuracy and performance. - Fine-tuning Support: Utilize fine-tuned models for specific tasks to achieve higher accuracy and performance.
- [Guidance](../conceptual/guidance): Enable function calling and tool-use by forcing the model to generate structured outputs based on your own predefined output schemas.
Text Generation Inference is used in production by multiple projects, such as: Text Generation Inference is used in production by multiple projects, such as:

View File

@ -293,6 +293,7 @@ def launcher(event_loop):
dtype: Optional[str] = None, dtype: Optional[str] = None,
revision: Optional[str] = None, revision: Optional[str] = None,
max_input_length: Optional[int] = None, max_input_length: Optional[int] = None,
max_batch_prefill_tokens: Optional[int] = None,
max_total_tokens: Optional[int] = None, max_total_tokens: Optional[int] = None,
): ):
port = random.randint(8000, 10_000) port = random.randint(8000, 10_000)
@ -334,6 +335,9 @@ def launcher(event_loop):
if max_input_length: if max_input_length:
args.append("--max-input-length") args.append("--max-input-length")
args.append(str(max_input_length)) args.append(str(max_input_length))
if max_batch_prefill_tokens:
args.append("--max-batch-prefill-tokens")
args.append(str(max_batch_prefill_tokens))
if max_total_tokens: if max_total_tokens:
args.append("--max-total-tokens") args.append("--max-total-tokens")
args.append(str(max_total_tokens)) args.append(str(max_total_tokens))
@ -371,6 +375,7 @@ def launcher(event_loop):
dtype: Optional[str] = None, dtype: Optional[str] = None,
revision: Optional[str] = None, revision: Optional[str] = None,
max_input_length: Optional[int] = None, max_input_length: Optional[int] = None,
max_batch_prefill_tokens: Optional[int] = None,
max_total_tokens: Optional[int] = None, max_total_tokens: Optional[int] = None,
): ):
port = random.randint(8000, 10_000) port = random.randint(8000, 10_000)
@ -395,6 +400,9 @@ def launcher(event_loop):
if max_input_length: if max_input_length:
args.append("--max-input-length") args.append("--max-input-length")
args.append(str(max_input_length)) args.append(str(max_input_length))
if max_batch_prefill_tokens:
args.append("--max-batch-prefill-tokens")
args.append(str(max_batch_prefill_tokens))
if max_total_tokens: if max_total_tokens:
args.append("--max-total-tokens") args.append("--max-total-tokens")
args.append(str(max_total_tokens)) args.append(str(max_total_tokens))

View File

@ -17,7 +17,7 @@
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.0-native", "system_fingerprint": "2.0.1-native",
"usage": { "usage": {
"completion_tokens": 100, "completion_tokens": 100,
"prompt_tokens": 60, "prompt_tokens": 60,

View File

@ -29,7 +29,7 @@
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.0-native", "system_fingerprint": "2.0.1-native",
"usage": { "usage": {
"completion_tokens": 36, "completion_tokens": 36,
"prompt_tokens": 8, "prompt_tokens": 8,

View File

@ -12,7 +12,7 @@
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.0-native" "system_fingerprint": "2.0.1-native"
}, },
{ {
"choices": [ "choices": [
@ -27,7 +27,7 @@
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.0-native" "system_fingerprint": "2.0.1-native"
}, },
{ {
"choices": [ "choices": [
@ -42,7 +42,7 @@
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.0-native" "system_fingerprint": "2.0.1-native"
}, },
{ {
"choices": [ "choices": [
@ -57,7 +57,7 @@
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.0-native" "system_fingerprint": "2.0.1-native"
}, },
{ {
"choices": [ "choices": [
@ -72,7 +72,7 @@
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.0-native" "system_fingerprint": "2.0.1-native"
}, },
{ {
"choices": [ "choices": [
@ -87,7 +87,7 @@
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.0-native" "system_fingerprint": "2.0.1-native"
}, },
{ {
"choices": [ "choices": [
@ -102,7 +102,7 @@
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.0-native" "system_fingerprint": "2.0.1-native"
}, },
{ {
"choices": [ "choices": [
@ -117,7 +117,7 @@
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.0-native" "system_fingerprint": "2.0.1-native"
}, },
{ {
"choices": [ "choices": [
@ -132,7 +132,7 @@
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.0-native" "system_fingerprint": "2.0.1-native"
}, },
{ {
"choices": [ "choices": [
@ -147,7 +147,7 @@
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.0-native" "system_fingerprint": "2.0.1-native"
}, },
{ {
"choices": [ "choices": [
@ -162,7 +162,7 @@
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.0-native" "system_fingerprint": "2.0.1-native"
}, },
{ {
"choices": [ "choices": [
@ -177,7 +177,7 @@
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.0-native" "system_fingerprint": "2.0.1-native"
}, },
{ {
"choices": [ "choices": [
@ -192,7 +192,7 @@
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.0-native" "system_fingerprint": "2.0.1-native"
}, },
{ {
"choices": [ "choices": [
@ -207,7 +207,7 @@
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.0-native" "system_fingerprint": "2.0.1-native"
}, },
{ {
"choices": [ "choices": [
@ -222,7 +222,7 @@
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.0-native" "system_fingerprint": "2.0.1-native"
}, },
{ {
"choices": [ "choices": [
@ -237,7 +237,7 @@
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.0-native" "system_fingerprint": "2.0.1-native"
}, },
{ {
"choices": [ "choices": [
@ -252,7 +252,7 @@
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.0-native" "system_fingerprint": "2.0.1-native"
}, },
{ {
"choices": [ "choices": [
@ -267,7 +267,7 @@
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.0-native" "system_fingerprint": "2.0.1-native"
}, },
{ {
"choices": [ "choices": [
@ -282,7 +282,7 @@
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.0-native" "system_fingerprint": "2.0.1-native"
}, },
{ {
"choices": [ "choices": [
@ -297,7 +297,7 @@
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.0-native" "system_fingerprint": "2.0.1-native"
}, },
{ {
"choices": [ "choices": [
@ -312,7 +312,7 @@
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.0-native" "system_fingerprint": "2.0.1-native"
}, },
{ {
"choices": [ "choices": [
@ -327,7 +327,7 @@
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.0-native" "system_fingerprint": "2.0.1-native"
}, },
{ {
"choices": [ "choices": [
@ -342,7 +342,7 @@
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.0-native" "system_fingerprint": "2.0.1-native"
}, },
{ {
"choices": [ "choices": [
@ -357,7 +357,7 @@
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.0-native" "system_fingerprint": "2.0.1-native"
}, },
{ {
"choices": [ "choices": [
@ -372,7 +372,7 @@
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.0-native" "system_fingerprint": "2.0.1-native"
}, },
{ {
"choices": [ "choices": [
@ -387,7 +387,7 @@
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.0-native" "system_fingerprint": "2.0.1-native"
}, },
{ {
"choices": [ "choices": [
@ -402,7 +402,7 @@
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.0-native" "system_fingerprint": "2.0.1-native"
}, },
{ {
"choices": [ "choices": [
@ -417,7 +417,7 @@
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.0-native" "system_fingerprint": "2.0.1-native"
}, },
{ {
"choices": [ "choices": [
@ -432,7 +432,7 @@
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.0-native" "system_fingerprint": "2.0.1-native"
}, },
{ {
"choices": [ "choices": [
@ -447,7 +447,7 @@
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.0-native" "system_fingerprint": "2.0.1-native"
}, },
{ {
"choices": [ "choices": [
@ -462,7 +462,7 @@
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.0-native" "system_fingerprint": "2.0.1-native"
}, },
{ {
"choices": [ "choices": [
@ -477,7 +477,7 @@
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.0-native" "system_fingerprint": "2.0.1-native"
}, },
{ {
"choices": [ "choices": [
@ -492,7 +492,7 @@
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.0-native" "system_fingerprint": "2.0.1-native"
}, },
{ {
"choices": [ "choices": [
@ -507,7 +507,7 @@
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.0-native" "system_fingerprint": "2.0.1-native"
}, },
{ {
"choices": [ "choices": [
@ -522,7 +522,7 @@
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.0-native" "system_fingerprint": "2.0.1-native"
}, },
{ {
"choices": [ "choices": [
@ -537,7 +537,7 @@
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.0-native" "system_fingerprint": "2.0.1-native"
}, },
{ {
"choices": [ "choices": [
@ -552,7 +552,7 @@
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.0-native" "system_fingerprint": "2.0.1-native"
}, },
{ {
"choices": [ "choices": [
@ -567,7 +567,7 @@
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.0-native" "system_fingerprint": "2.0.1-native"
}, },
{ {
"choices": [ "choices": [
@ -582,7 +582,7 @@
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.0-native" "system_fingerprint": "2.0.1-native"
}, },
{ {
"choices": [ "choices": [
@ -597,6 +597,6 @@
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.0-native" "system_fingerprint": "2.0.1-native"
} }
] ]

View File

@ -11,7 +11,7 @@
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.0-native", "system_fingerprint": "2.0.1-native",
"usage": { "usage": {
"completion_tokens": 5, "completion_tokens": 5,
"prompt_tokens": 6, "prompt_tokens": 6,

View File

@ -0,0 +1,89 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 3735,
"logprob": -8.5625,
"text": "Test"
},
{
"id": 2159,
"logprob": -10.78125,
"text": "request"
}
],
"seed": 0,
"tokens": [
{
"id": 288,
"logprob": -0.2854004,
"special": false,
"text": "ing"
},
{
"id": 264,
"logprob": -0.37573242,
"special": false,
"text": " a"
},
{
"id": 633,
"logprob": -0.09301758,
"special": false,
"text": " new"
},
{
"id": 4480,
"logprob": -0.3322754,
"special": false,
"text": " feature"
},
{
"id": 297,
"logprob": -0.8510742,
"special": false,
"text": " in"
},
{
"id": 272,
"logprob": -0.13464355,
"special": false,
"text": " the"
},
{
"id": 2039,
"logprob": 0.0,
"special": false,
"text": " game"
},
{
"id": 28723,
"logprob": -0.89990234,
"special": false,
"text": "."
},
{
"id": 13,
"logprob": 0.0,
"special": false,
"text": "\n"
},
{
"id": 13,
"logprob": 0.0,
"special": false,
"text": "\n"
}
],
"top_tokens": null
},
"generated_text": "Test requesting a new feature in the game.\n\n"
}

View File

@ -0,0 +1,73 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [],
"seed": null,
"tokens": [
{
"id": 330,
"logprob": -0.13000488,
"special": false,
"text": " A"
},
{
"id": 13088,
"logprob": -0.6713867,
"special": false,
"text": " chicken"
},
{
"id": 349,
"logprob": -0.2980957,
"special": false,
"text": " is"
},
{
"id": 6398,
"logprob": -0.060638428,
"special": false,
"text": " sitting"
},
{
"id": 356,
"logprob": -0.27319336,
"special": false,
"text": " on"
},
{
"id": 264,
"logprob": -0.140625,
"special": false,
"text": " a"
},
{
"id": 17972,
"logprob": -0.040405273,
"special": false,
"text": " pile"
},
{
"id": 302,
"logprob": -0.0002708435,
"special": false,
"text": " of"
},
{
"id": 2445,
"logprob": -0.095336914,
"special": false,
"text": " money"
},
{
"id": 28723,
"logprob": -0.0068359375,
"special": false,
"text": "."
}
],
"top_tokens": null
},
"generated_text": " A chicken is sitting on a pile of money."
}

View File

@ -30,7 +30,7 @@
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.0-native", "system_fingerprint": "2.0.1-native",
"usage": { "usage": {
"completion_tokens": 37, "completion_tokens": 37,
"prompt_tokens": 524, "prompt_tokens": 524,

View File

@ -30,7 +30,7 @@
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.0-native", "system_fingerprint": "2.0.1-native",
"usage": { "usage": {
"completion_tokens": 37, "completion_tokens": 37,
"prompt_tokens": 524, "prompt_tokens": 524,

View File

@ -30,7 +30,7 @@
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.0-native", "system_fingerprint": "2.0.1-native",
"usage": { "usage": {
"completion_tokens": 48, "completion_tokens": 48,
"prompt_tokens": 320, "prompt_tokens": 320,

View File

@ -23,5 +23,5 @@
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.0-native" "system_fingerprint": "2.0.1-native"
} }

View File

@ -0,0 +1,81 @@
import pytest
import base64
# TODO fix the server parsser to count inline image tokens correctly
def get_chicken():
with open("integration-tests/images/chicken_on_money.png", "rb") as image_file:
encoded_string = base64.b64encode(image_file.read())
return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
@pytest.fixture(scope="module")
def flash_idefics2_next_handle(launcher):
with launcher(
"HuggingFaceM4/idefics2-8b",
) as handle:
yield handle
@pytest.fixture(scope="module")
async def flash_idefics2_next(flash_idefics2_next_handle):
await flash_idefics2_next_handle.health(300)
return flash_idefics2_next_handle.client
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_idefics2_next_simple(flash_idefics2_next, response_snapshot):
chicken = get_chicken()
response = await flash_idefics2_next.generate(
f"User:![]({chicken})Write me a short story<end_of_utterance> \nAssistant:",
max_new_tokens=10,
)
assert (
response.generated_text == " A chicken is sitting on a pile of money."
), f"{repr(response.generated_text)}"
assert response.details.generated_tokens == 10
assert response == response_snapshot
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_idefics2_next_all_params(flash_idefics2_next, response_snapshot):
response = await flash_idefics2_next.generate(
"Test request",
max_new_tokens=10,
repetition_penalty=1.2,
return_full_text=True,
stop_sequences=["test"],
temperature=0.5,
top_p=0.9,
top_k=10,
truncate=5,
typical_p=0.9,
watermark=True,
decoder_input_details=True,
seed=0,
)
assert response.details.generated_tokens == 10
assert response == response_snapshot
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_idefics2_next_load(
flash_idefics2_next, generate_load, response_snapshot
):
chicken = get_chicken()
responses = await generate_load(
flash_idefics2_next,
f"User:![]({chicken})Write me a short story<end_of_utterance> \nAssistant:",
max_new_tokens=10,
n=4,
)
generated_texts = [r.generated_text for r in responses]
assert generated_texts[0] == " A chicken is sitting on a pile of money."
assert len(generated_texts) == 4
assert all([r.generated_text == generated_texts[0] for r in responses])
assert responses == response_snapshot

View File

@ -1,6 +1,6 @@
[tool.poetry] [tool.poetry]
name = "text-generation-integration-tests" name = "text-generation-integration-tests"
version = "2.0.0" version = "2.0.1"
description = "Text Generation Inference integration tests" description = "Text Generation Inference integration tests"
authors = ["Nicolas Patry <nicolas@huggingface.co>"] authors = ["Nicolas Patry <nicolas@huggingface.co>"]

View File

@ -1,71 +1,94 @@
import { check, randomSeed } from 'k6'; import { check } from 'k6';
import { scenario } from 'k6/execution';
import http from 'k6/http'; import http from 'k6/http';
import { Trend, Counter } from 'k6/metrics'; import { Trend, Counter } from 'k6/metrics';
import { randomItem } from 'https://jslib.k6.io/k6-utils/1.2.0/index.js';
const seed = 0; const host = __ENV.HOST;
const model_id = __ENV.MODEL_ID;
const host = __ENV.HOST || '127.0.0.1:8000';
const timePerToken = new Trend('time_per_token', true); const timePerToken = new Trend('time_per_token', true);
const tokens = new Counter('tokens'); const tokens = new Counter('tokens');
const new_tokens = new Counter('new_tokens'); const new_tokens = new Counter('new_tokens');
const input_tokens = new Counter('input_tokens'); const input_tokens = new Counter('input_tokens');
const max_new_tokens = 50;
randomSeed(seed);
// const shareGPT = JSON.parse(open("ShareGPT_V3_unfiltered_cleaned_split.json")) // const shareGPT = JSON.parse(open("ShareGPT_V3_unfiltered_cleaned_split.json"))
const shareGPT = JSON.parse(open("small.json")) const shareGPT = JSON.parse(open("small.json"))
export function get_options(reference_latency_ms){ export function get_options() {
return { return {
thresholds: { thresholds: {
http_req_failed: ['rate==0'], http_req_failed: ['rate==0'],
time_per_token: [{ // time_per_token: [{
threshold: `p(50)<${5 * reference_latency_ms}`, // threshold: `p(50)<${5 * reference_latency_ms}`,
abortOnFail: true, // abortOnFail: true,
delayAbortEval: '10s' // delayAbortEval: '10s'
}], // }],
}, },
scenarios: { scenarios: {
load_test: { single_user: {
executor: 'constant-arrival-rate', executor: 'constant-arrival-rate',
duration: '60s', duration: '60s',
preAllocatedVUs: 10, preAllocatedVUs: 1,
rate: 10, rate: 1,
timeUnit: '1s', timeUnit: '1s',
}, },
// load_test: {
// executor: 'constant-arrival-rate',
// duration: '60s',
// preAllocatedVUs: 100,
// rate: 1,
// timeUnit: '1s',
// },
// breakpoint: {
// executor: 'ramping-arrival-rate', //Assure load increase if the system slows
// preAllocatedVUs: 1000,
// stages: [
// { duration: '60s', target: 100 }, // just slowly ramp-up to a HUGE load
// ],
// },
// throughput: {
// executor: 'shared-iterations',
// vus: 100,
// iterations: 200,
// maxDuration: '40s',
// },
}, },
}; };
} }
function generate_payload(gpt, max_new_tokens) {
const input = gpt["conversations"][0]["value"];
return { "messages": [{ "role": "user", "content": input }], "temperature": 0, "model": `${model_id}`, "max_tokens": max_new_tokens }
}
export function run(host, generate_payload, max_new_tokens) { export const options = get_options();
const headers = {'Content-Type': 'application/json'};
const query = randomItem(shareGPT); export default function run() {
const payload = JSON.stringify(generate_payload(query)); const headers = { 'Content-Type': 'application/json' };
const res = http.post(`http://${host}/generate`, payload, { const query = shareGPT[scenario.iterationInTest % shareGPT.length];
const payload = JSON.stringify(generate_payload(query, max_new_tokens));
const res = http.post(`http://${host}/v1/chat/completions`, payload, {
headers, headers,
}); });
if(res.status >= 400 && res.status < 500){ if (res.status >= 400 && res.status < 500) {
return; return;
} }
check(res, { check(res, {
'Post status is 200': (r) => res.status === 200, 'Post status is 200': (res) => res.status === 200,
}); });
const duration = res.timings.duration; const duration = res.timings.duration;
if (res.status === 200) { if (res.status === 200) {
const body = res.json(); const body = res.json();
const n_tokens = body.details.tokens.length; const completion_tokens = body.usage.completion_tokens;
const latency_ms_per_token = duration / n_tokens; const latency_ms_per_token = duration / completion_tokens;
timePerToken.add(latency_ms_per_token); timePerToken.add(latency_ms_per_token);
const latency_in_s = latency_ms_per_token / 1000; const prompt_tokens = body.usage.prompt_tokens;
const individual_throughput = 1 / latency_in_s; input_tokens.add(prompt_tokens);
const _input_tokens = body.details.prefill.length; new_tokens.add(completion_tokens);
tokens.add(n_tokens + _input_tokens); tokens.add(completion_tokens + prompt_tokens);
input_tokens.add(_input_tokens);
new_tokens.add(n_tokens);
} }
} }

View File

@ -1,17 +0,0 @@
import { get_options, run } from "./common.js";
const reference_latency_ms = 70;
const host = __ENV.HOST || '127.0.0.1:8000';
const max_new_tokens = 50;
function generate_payload(gpt){
const input = gpt["conversations"][0]["value"];
return {"inputs": input, "parameters": {"max_new_tokens": max_new_tokens, "decoder_input_details": true}}
}
export const options = get_options(reference_latency_ms);
export default function(){
run(host, generate_payload, max_new_tokens);
}

View File

@ -1,17 +0,0 @@
import { get_options, run } from "./common.js";
const reference_latency_ms = 22;
const host = __ENV.HOST || '127.0.0.1:8000';
const max_new_tokens = 50;
function generate_payload(gpt){
const input = gpt["conversations"][0]["value"];
return {"prompt": input, "temperature": 0.5, "ignore_eos": true}
}
export const options = get_options(reference_latency_ms);
export default function(){
run(host, generate_payload, max_new_tokens);
}

View File

@ -21,7 +21,7 @@ axum-tracing-opentelemetry = "0.14.1"
text-generation-client = { path = "client" } text-generation-client = { path = "client" }
clap = { version = "4.4.5", features = ["derive", "env"] } clap = { version = "4.4.5", features = ["derive", "env"] }
futures = "0.3.28" futures = "0.3.28"
hf-hub = { version = "0.3.0", features = ["tokio"] } hf-hub = { workspace = true }
jsonschema = { version = "0.17.1", features = ["draft202012"] } jsonschema = { version = "0.17.1", features = ["draft202012"] }
metrics = "0.21.1" metrics = "0.21.1"
metrics-exporter-prometheus = { version = "0.12.1", features = [] } metrics-exporter-prometheus = { version = "0.12.1", features = [] }
@ -33,7 +33,7 @@ reqwest = { version = "0.11.20", features = [] }
serde = "1.0.188" serde = "1.0.188"
serde_json = "1.0.107" serde_json = "1.0.107"
thiserror = "1.0.48" thiserror = "1.0.48"
tokenizers = { version = "0.15.1", features = ["http"] } tokenizers = { workspace = true}
tokio = { version = "1.32.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] } tokio = { version = "1.32.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
tokio-stream = "0.1.14" tokio-stream = "0.1.14"
tower-http = { version = "0.4.4", features = ["cors"] } tower-http = { version = "0.4.4", features = ["cors"] }

View File

@ -114,8 +114,12 @@ impl Client {
let truncate = min(max_input_length, max_prefill_tokens - n_tokens); let truncate = min(max_input_length, max_prefill_tokens - n_tokens);
let mut inputs = String::new(); let mut inputs = String::new();
inputs.push_str("![](data:image/jpeg;base64,iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAABg2lDQ1BJQ0MgcHJvZmlsZQAAKJF9kT1Iw0AcxV/TSotUROxQxCFDdbKLijjWKhShQqgVWnUwufQLmrQkKS6OgmvBwY/FqoOLs64OroIg+AHi7OCk6CIl/i8ptIjx4Lgf7+497t4BQqvKNDOQADTdMjKppJjLr4rBVwQQwhAERGVm1uckKQ3P8XUPH1/v4jzL+9yfY0AtmAzwicQJVjcs4g3imU2rznmfOMLKskp8Tjxh0AWJH7muuPzGueSwwDMjRjYzTxwhFks9rPQwKxsa8TRxTNV0yhdyLquctzhr1Qbr3JO/MFzQV5a5TnMUKSxiCRJEKGiggiosxGnVSTGRof2kh3/E8UvkUshVASPHAmrQIDt+8D/43a1ZnJp0k8JJoO/Ftj/GgOAu0G7a9vexbbdPAP8zcKV3/bUWMPtJerOrxY6AwW3g4rqrKXvA5Q4QfarLhuxIfppCsQi8n9E35YHhW6B/ze2ts4/TByBLXaVvgINDYLxE2ese7w719vbvmU5/PycecohsjayNAAAACXBIWXMAAC4jAAAuIwF4pT92AAAAB3RJTUUH6AQIEQMnlTSSjwAAABl0RVh0Q29tbWVudABDcmVhdGVkIHdpdGggR0lNUFeBDhcAAAASSURBVDjLY2AYBaNgFIyCoQsABMQAAeRw1DoAAAAASUVORK5CYII=");
inputs.push_str(&"_test ".to_string().repeat(max_input_length as usize)); inputs.push_str(&"_test ".to_string().repeat(max_input_length as usize));
if n_tokens == 0 {
// 1 request is enough to test vision heads.
// Sending images on other queries messes up easily with truncation.
inputs.push_str("![](data:image/jpeg;base64,iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAABg2lDQ1BJQ0MgcHJvZmlsZQAAKJF9kT1Iw0AcxV/TSotUROxQxCFDdbKLijjWKhShQqgVWnUwufQLmrQkKS6OgmvBwY/FqoOLs64OroIg+AHi7OCk6CIl/i8ptIjx4Lgf7+497t4BQqvKNDOQADTdMjKppJjLr4rBVwQQwhAERGVm1uckKQ3P8XUPH1/v4jzL+9yfY0AtmAzwicQJVjcs4g3imU2rznmfOMLKskp8Tjxh0AWJH7muuPzGueSwwDMjRjYzTxwhFks9rPQwKxsa8TRxTNV0yhdyLquctzhr1Qbr3JO/MFzQV5a5TnMUKSxiCRJEKGiggiosxGnVSTGRof2kh3/E8UvkUshVASPHAmrQIDt+8D/43a1ZnJp0k8JJoO/Ftj/GgOAu0G7a9vexbbdPAP8zcKV3/bUWMPtJerOrxY6AwW3g4rqrKXvA5Q4QfarLhuxIfppCsQi8n9E35YHhW6B/ze2ts4/TByBLXaVvgINDYLxE2ese7w719vbvmU5/PycecohsjayNAAAACXBIWXMAAC4jAAAuIwF4pT92AAAAB3RJTUUH6AQIEQMnlTSSjwAAABl0RVh0Q29tbWVudABDcmVhdGVkIHdpdGggR0lNUFeBDhcAAAASSURBVDjLY2AYBaNgFIyCoQsABMQAAeRw1DoAAAAASUVORK5CYII=)");
}
requests.push(Request { requests.push(Request {
id: 0, id: 0,

View File

@ -57,6 +57,31 @@ fn select_best_resolution(
best_fit.unwrap_or((original_height, original_width)) best_fit.unwrap_or((original_height, original_width))
} }
fn get_unpadded_features(
height: usize,
width: usize,
npatches: usize,
num_patch_height: usize,
num_patch_width: usize,
) -> (usize, usize) {
let current_height = npatches * num_patch_height;
let current_width = npatches * num_patch_width;
let aspect_ratio: f64 = width as f64 / height as f64;
let current_aspect_ratio: f64 = current_width as f64 / current_height as f64;
let (current_height, current_width) = if aspect_ratio > current_aspect_ratio {
let new_height = (height * current_width) / width;
(new_height, current_width)
} else {
let new_width = (width * current_height) / height;
(current_height, new_width)
};
let unpadded_features = current_height * current_width;
let newline_features = current_height;
(unpadded_features, newline_features)
}
impl LlavaNext { impl LlavaNext {
pub fn get_number_of_features(&self, height: usize, width: usize) -> usize { pub fn get_number_of_features(&self, height: usize, width: usize) -> usize {
let image_size = self.vision_config.image_size; let image_size = self.vision_config.image_size;
@ -65,11 +90,9 @@ impl LlavaNext {
let npatches = image_size / patch_size; let npatches = image_size / patch_size;
let (num_patch_height, num_patch_width) = let (num_patch_height, num_patch_width) =
get_anyres_image_grid_shape(height, width, &self.image_grid_pinpoints, image_size); get_anyres_image_grid_shape(height, width, &self.image_grid_pinpoints, image_size);
// Ceil
let height_of_patch = (height * npatches + width - 1) / width; let (unpadded_features, newline_features) =
let unpadded_features = npatches * height_of_patch * num_patch_height * num_patch_width; get_unpadded_features(height, width, npatches, num_patch_height, num_patch_width);
// They are only added after width
let newline_features = height_of_patch * num_patch_width;
// The base patch covers the entire image // The base patch covers the entire image
let base_features = npatches.pow(2); let base_features = npatches.pow(2);
unpadded_features + newline_features + base_features unpadded_features + newline_features + base_features
@ -84,6 +107,17 @@ pub struct ClipVisionModel {
patch_size: usize, patch_size: usize,
} }
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(tag = "model_type")]
#[serde(rename_all = "snake_case")]
pub struct Idefics2 {}
impl Idefics2 {
pub fn get_number_of_features(&self, _height: usize, _width: usize) -> usize {
320
}
}
#[derive(Clone, Debug, Serialize, Deserialize)] #[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(tag = "model_type")] #[serde(tag = "model_type")]
#[serde(rename_all = "snake_case")] #[serde(rename_all = "snake_case")]
@ -92,6 +126,7 @@ pub enum Config {
ClipVisionModel(ClipVisionModel), ClipVisionModel(ClipVisionModel),
Mistral, Mistral,
Idefics, Idefics,
Idefics2(Idefics2),
Ssm, Ssm,
GptBigcode, GptBigcode,
Santacoder, Santacoder,
@ -146,13 +181,17 @@ mod test {
], ],
}; };
let slots = config.get_number_of_features(20, 20);
assert_eq!(slots, 1176);
let slots = config.get_number_of_features(640, 640); let slots = config.get_number_of_features(640, 640);
assert_eq!(slots, 2928); assert_eq!(slots, 2928);
let slots = config.get_number_of_features(480, 640); let slots = config.get_number_of_features(480, 640);
assert_eq!(slots, 2340); assert_eq!(slots, 2340);
let slots = config.get_number_of_features(899, 1024); let slots = config.get_number_of_features(899, 1024);
assert_eq!(slots, 2732); assert_eq!(slots, 2634);
let slots = config.get_number_of_features(1024, 899); let slots = config.get_number_of_features(1024, 899);
assert_eq!(slots, 3320); assert_eq!(slots, 2640);
let slots = config.get_number_of_features(1067, 1600);
assert_eq!(slots, 2144);
} }
} }

View File

@ -73,9 +73,9 @@ pub struct HubTokenizerConfig {
} }
impl HubTokenizerConfig { impl HubTokenizerConfig {
pub fn from_file(filename: &std::path::Path) -> Self { pub fn from_file<P: AsRef<std::path::Path>>(filename: P) -> Option<Self> {
let content = std::fs::read_to_string(filename).unwrap(); let content = std::fs::read_to_string(filename).ok()?;
serde_json::from_str(&content).unwrap_or_default() serde_json::from_str(&content).ok()
} }
} }
@ -116,6 +116,7 @@ mod token_serde {
)) ))
} }
} }
Value::Null => Ok(None),
_ => Err(de::Error::custom("invalid token format")), _ => Err(de::Error::custom("invalid token format")),
} }
} }
@ -168,9 +169,12 @@ pub struct Info {
#[derive(Clone, Debug, Deserialize, ToSchema, Default)] #[derive(Clone, Debug, Deserialize, ToSchema, Default)]
pub(crate) struct GenerateParameters { pub(crate) struct GenerateParameters {
/// Generate best_of sequences and return the one if the highest token logprobs.
#[serde(default)] #[serde(default)]
#[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 1)] #[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 1)]
pub best_of: Option<usize>, pub best_of: Option<usize>,
/// The value used to module the logits distribution.
#[serde(default)] #[serde(default)]
#[schema( #[schema(
exclusive_minimum = 0.0, exclusive_minimum = 0.0,
@ -179,6 +183,9 @@ pub(crate) struct GenerateParameters {
example = 0.5 example = 0.5
)] )]
pub temperature: Option<f32>, pub temperature: Option<f32>,
/// The parameter for repetition penalty. 1.0 means no penalty.
/// See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
#[serde(default)] #[serde(default)]
#[schema( #[schema(
exclusive_minimum = 0.0, exclusive_minimum = 0.0,
@ -187,6 +194,10 @@ pub(crate) struct GenerateParameters {
example = 1.03 example = 1.03
)] )]
pub repetition_penalty: Option<f32>, pub repetition_penalty: Option<f32>,
/// The parameter for frequency penalty. 1.0 means no penalty
/// Penalize new tokens based on their existing frequency in the text so far,
/// decreasing the model's likelihood to repeat the same line verbatim.
#[serde(default)] #[serde(default)]
#[schema( #[schema(
exclusive_minimum = -2.0, exclusive_minimum = -2.0,
@ -195,9 +206,13 @@ pub(crate) struct GenerateParameters {
example = 0.1 example = 0.1
)] )]
pub frequency_penalty: Option<f32>, pub frequency_penalty: Option<f32>,
/// The number of highest probability vocabulary tokens to keep for top-k-filtering.
#[serde(default)] #[serde(default)]
#[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 10)] #[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 10)]
pub top_k: Option<i32>, pub top_k: Option<i32>,
/// Top-p value for nucleus sampling.
#[serde(default)] #[serde(default)]
#[schema( #[schema(
exclusive_minimum = 0.0, exclusive_minimum = 0.0,
@ -207,6 +222,9 @@ pub(crate) struct GenerateParameters {
example = 0.95 example = 0.95
)] )]
pub top_p: Option<f32>, pub top_p: Option<f32>,
/// Typical Decoding mass
/// See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information.
#[serde(default)] #[serde(default)]
#[schema( #[schema(
exclusive_minimum = 0.0, exclusive_minimum = 0.0,
@ -216,30 +234,48 @@ pub(crate) struct GenerateParameters {
example = 0.95 example = 0.95
)] )]
pub typical_p: Option<f32>, pub typical_p: Option<f32>,
/// Activate logits sampling.
#[serde(default)] #[serde(default)]
#[schema(default = "false", example = true)] #[schema(default = "false", example = true)]
pub do_sample: bool, pub do_sample: bool,
/// Maximum number of tokens to generate.
#[serde(default = "default_max_new_tokens")] #[serde(default = "default_max_new_tokens")]
#[schema(nullable = true, default = "100", example = "20")] #[schema(nullable = true, default = "100", example = "20")]
pub max_new_tokens: Option<u32>, pub max_new_tokens: Option<u32>,
/// Whether to prepend the prompt to the generated text
#[serde(default)] #[serde(default)]
#[schema(nullable = true, default = "null", example = false)] #[schema(nullable = true, default = "null", example = false)]
pub return_full_text: Option<bool>, pub return_full_text: Option<bool>,
/// Stop generating tokens if a member of `stop` is generated.
#[serde(default)] #[serde(default)]
#[schema(inline, max_items = 4, example = json ! (["photographer"]))] #[schema(inline, max_items = 4, example = json ! (["photographer"]))]
pub stop: Vec<String>, pub stop: Vec<String>,
/// Truncate inputs tokens to the given size.
#[serde(default)] #[serde(default)]
#[schema(nullable = true, default = "null", example = "null")] #[schema(nullable = true, default = "null", example = "null")]
pub truncate: Option<usize>, pub truncate: Option<usize>,
/// Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226).
#[serde(default)] #[serde(default)]
#[schema(default = "false", example = true)] #[schema(default = "false", example = true)]
pub watermark: bool, pub watermark: bool,
/// Whether to return generation details.
#[serde(default)] #[serde(default)]
#[schema(default = "true")] #[schema(default = "true")]
pub details: bool, pub details: bool,
/// Whether to return decoder input token logprobs and ids.
#[serde(default)] #[serde(default)]
#[schema(default = "true")] #[schema(default = "false")]
pub decoder_input_details: bool, pub decoder_input_details: bool,
/// Random sampling seed.
#[serde(default)] #[serde(default)]
#[schema( #[schema(
exclusive_minimum = 0, exclusive_minimum = 0,
@ -248,10 +284,15 @@ pub(crate) struct GenerateParameters {
example = "null" example = "null"
)] )]
pub seed: Option<u64>, pub seed: Option<u64>,
/// The number of highest probability vocabulary tokens to keep for top-n-filtering.
#[serde(default)] #[serde(default)]
#[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 5)] #[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 5)]
pub top_n_tokens: Option<u32>, pub top_n_tokens: Option<u32>,
/// Grammar constraints for the generation.
#[serde(default)] #[serde(default)]
#[schema(nullable = true, default = "null", example = "null")]
pub grammar: Option<GrammarType>, pub grammar: Option<GrammarType>,
} }
@ -548,7 +589,9 @@ pub(crate) struct ChatCompletionChoice {
#[derive(Clone, Debug, Deserialize, Serialize, ToSchema)] #[derive(Clone, Debug, Deserialize, Serialize, ToSchema)]
pub(crate) struct ChatCompletionDelta { pub(crate) struct ChatCompletionDelta {
#[schema(example = "user")] #[schema(example = "user")]
pub role: String, // TODO Modify this to a true enum.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub role: Option<String>,
#[serde(default, skip_serializing_if = "Option::is_none")] #[serde(default, skip_serializing_if = "Option::is_none")]
#[schema(example = "What is Deep Learning?")] #[schema(example = "What is Deep Learning?")]
pub content: Option<String>, pub content: Option<String>,
@ -582,6 +625,31 @@ impl ChatCompletionChunk {
logprobs: Option<ChatCompletionLogprobs>, logprobs: Option<ChatCompletionLogprobs>,
finish_reason: Option<String>, finish_reason: Option<String>,
) -> Self { ) -> Self {
let delta = match (delta, tool_calls) {
(Some(delta), _) => ChatCompletionDelta {
role: Some("assistant".to_string()),
content: Some(delta),
tool_calls: None,
},
(None, Some(tool_calls)) => ChatCompletionDelta {
role: Some("assistant".to_string()),
content: None,
tool_calls: Some(DeltaToolCall {
index: 0,
id: String::new(),
r#type: "function".to_string(),
function: Function {
name: None,
arguments: tool_calls[0].to_string(),
},
}),
},
(None, None) => ChatCompletionDelta {
role: None,
content: None,
tool_calls: None,
},
};
Self { Self {
id: String::new(), id: String::new(),
object: "text_completion".to_string(), object: "text_completion".to_string(),
@ -590,19 +658,7 @@ impl ChatCompletionChunk {
system_fingerprint, system_fingerprint,
choices: vec![ChatCompletionChoice { choices: vec![ChatCompletionChoice {
index: 0, index: 0,
delta: ChatCompletionDelta { delta,
role: "assistant".to_string(),
content: delta,
tool_calls: tool_calls.map(|tc| DeltaToolCall {
index: 0,
id: String::new(),
r#type: "function".to_string(),
function: Function {
name: None,
arguments: tc[0].to_string(),
},
}),
},
logprobs, logprobs,
finish_reason, finish_reason,
}], }],

View File

@ -1,7 +1,7 @@
use axum::http::HeaderValue; use axum::http::HeaderValue;
use clap::Parser; use clap::Parser;
use hf_hub::api::tokio::{Api, ApiBuilder, ApiRepo}; use hf_hub::api::tokio::{Api, ApiBuilder, ApiRepo};
use hf_hub::{Repo, RepoType}; use hf_hub::{Cache, Repo, RepoType};
use opentelemetry::sdk::propagation::TraceContextPropagator; use opentelemetry::sdk::propagation::TraceContextPropagator;
use opentelemetry::sdk::trace; use opentelemetry::sdk::trace;
use opentelemetry::sdk::trace::Sampler; use opentelemetry::sdk::trace::Sampler;
@ -11,7 +11,7 @@ use opentelemetry_otlp::WithExportConfig;
use std::fs::File; use std::fs::File;
use std::io::BufReader; use std::io::BufReader;
use std::net::{IpAddr, Ipv4Addr, SocketAddr}; use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::path::Path; use std::path::{Path, PathBuf};
use text_generation_client::{ClientError, ShardedClient}; use text_generation_client::{ClientError, ShardedClient};
use text_generation_router::config::Config; use text_generation_router::config::Config;
use text_generation_router::{server, HubModelInfo, HubTokenizerConfig}; use text_generation_router::{server, HubModelInfo, HubTokenizerConfig};
@ -162,7 +162,6 @@ async fn main() -> Result<(), RouterError> {
// Tokenizer instance // Tokenizer instance
// This will only be used to validate payloads // This will only be used to validate payloads
let local_path = Path::new(&tokenizer_name); let local_path = Path::new(&tokenizer_name);
let local_model = local_path.exists() && local_path.is_dir();
// Shared API builder initialization // Shared API builder initialization
let api_builder = || { let api_builder = || {
@ -181,109 +180,113 @@ async fn main() -> Result<(), RouterError> {
let use_api = revision.is_some() || !local_path.exists() || !local_path.is_dir(); let use_api = revision.is_some() || !local_path.exists() || !local_path.is_dir();
// Initialize API if needed // Initialize API if needed
#[derive(Clone)]
enum Type {
Api(Api),
Cache(Cache),
None,
}
let api = if use_api { let api = if use_api {
tracing::info!("Using the Hugging Face API"); if std::env::var("HF_HUB_OFFLINE") == Ok("1".to_string()) {
match api_builder().build() { let cache = Cache::default();
Ok(api) => Some(api), tracing::warn!("Offline mode active using cache defaults");
Err(_) => { Type::Cache(cache)
tracing::warn!("Unable to build the Hugging Face API"); } else {
None tracing::info!("Using the Hugging Face API");
match api_builder().build() {
Ok(api) => Type::Api(api),
Err(_) => {
tracing::warn!("Unable to build the Hugging Face API");
Type::None
}
} }
} }
} else { } else {
None Type::None
}; };
// Load tokenizer and model info // Load tokenizer and model info
let (tokenizer, model_info, config) = if local_model { let (tokenizer_filename, config_filename, tokenizer_config_filename, model_info) = match api {
let tokenizer = Tokenizer::from_file(local_path.join("tokenizer.json")).ok(); Type::None => (
let model_info = HubModelInfo { Some(local_path.join("tokenizer.json")),
model_id: tokenizer_name.to_string(), Some(local_path.join("config.json")),
sha: None, Some(local_path.join("tokenizer_config.json")),
pipeline_tag: None, None,
}; ),
let config: Option<Config> = std::fs::read_to_string(local_path.join("config.json")) Type::Api(api) => {
.ok() let api_repo = api.repo(Repo::with_revision(
.as_ref() tokenizer_name.to_string(),
.and_then(|c| serde_json::from_str(c).ok()); RepoType::Model,
revision.clone().unwrap_or_else(|| "main".to_string()),
));
(tokenizer, model_info, config) let tokenizer_filename = match api_repo.get("tokenizer.json").await {
} else if let Some(api) = api.clone() { Ok(tokenizer_filename) => Some(tokenizer_filename),
let api_repo = api.repo(Repo::with_revision( Err(_) => get_base_tokenizer(&api, &api_repo).await,
tokenizer_name.to_string(), };
RepoType::Model, let config_filename = api_repo.get("config.json").await.ok();
revision.clone().unwrap_or_else(|| "main".to_string()), let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok();
));
let tokenizer = match api_repo.get("tokenizer.json").await { let model_info = if let Some(model_info) = get_model_info(&api_repo).await {
Ok(tokenizer_filename) => Tokenizer::from_file(tokenizer_filename).ok(), Some(model_info)
Err(_) => get_base_tokenizer(&api, &api_repo).await, } else {
}; tracing::warn!("Could not retrieve model info from the Hugging Face hub.");
None
let config: Option<Config> = api_repo.get("config.json").await.ok().and_then(|filename| { };
std::fs::read_to_string(filename) (
.ok() tokenizer_filename,
.as_ref() config_filename,
.and_then(|c| { tokenizer_config_filename,
let config: Result<Config, _> = serde_json::from_str(c); model_info,
if let Err(err) = &config { )
tracing::warn!("Could not parse config {err:?}"); }
} Type::Cache(cache) => {
config.ok() let repo = cache.repo(Repo::with_revision(
}) tokenizer_name.to_string(),
}); RepoType::Model,
revision.clone().unwrap_or_else(|| "main".to_string()),
let model_info = get_model_info(&api_repo).await.unwrap_or_else(|| { ));
tracing::warn!("Could not retrieve model info from the Hugging Face hub."); (
HubModelInfo { repo.get("tokenizer.json"),
model_id: tokenizer_name.to_string(), repo.get("config.json"),
sha: None, repo.get("tokenizer_config.json"),
pipeline_tag: None, None,
} )
});
(tokenizer, model_info, config)
} else {
// No API and no local model
return Err(RouterError::ArgumentValidation(
"No local model found and no revision specified".to_string(),
));
};
tracing::info!("Using config {config:?}");
// Load tokenizer config if found locally, or check if we can get it from the API if needed
let tokenizer_config = if let Some(path) = tokenizer_config_path {
tracing::info!("Using local tokenizer config from user specified path");
HubTokenizerConfig::from_file(&std::path::PathBuf::from(path))
} else if local_model {
tracing::info!("Using local tokenizer config");
HubTokenizerConfig::from_file(&local_path.join("tokenizer_config.json"))
} else {
match api {
Some(api) => {
tracing::info!("Using the Hugging Face API to retrieve tokenizer config");
let repo = Repo::with_revision(
tokenizer_name.to_string(),
RepoType::Model,
revision.unwrap_or("main".to_string()),
);
get_tokenizer_config(&api.repo(repo))
.await
.unwrap_or_else(|| {
tracing::warn!(
"Could not retrieve tokenizer config from the Hugging Face hub."
);
HubTokenizerConfig::default()
})
}
None => {
tracing::warn!("Could not find tokenizer config locally and no API specified");
HubTokenizerConfig::default()
}
} }
}; };
let tokenizer: Option<Tokenizer> =
tokenizer_filename.and_then(|filename| Tokenizer::from_file(filename).ok());
let config: Option<Config> = config_filename.and_then(|filename| {
std::fs::read_to_string(filename)
.ok()
.as_ref()
.and_then(|c| {
let config: Result<Config, _> = serde_json::from_str(c);
if let Err(err) = &config {
tracing::warn!("Could not parse config {err:?}");
}
config.ok()
})
});
let model_info = model_info.unwrap_or_else(|| HubModelInfo {
model_id: tokenizer_name.to_string(),
sha: None,
pipeline_tag: None,
});
// Read the JSON contents of the file as an instance of 'HubTokenizerConfig'.
let tokenizer_config: Option<HubTokenizerConfig> = if let Some(filename) = tokenizer_config_path
{
HubTokenizerConfig::from_file(filename)
} else {
tokenizer_config_filename.and_then(HubTokenizerConfig::from_file)
};
let tokenizer_config = tokenizer_config.unwrap_or_else(|| {
tracing::warn!("Could not find tokenizer config locally and no API specified");
HubTokenizerConfig::default()
});
tracing::info!("Using config {config:?}");
if tokenizer.is_none() { if tokenizer.is_none() {
tracing::warn!("Could not find a fast tokenizer implementation for {tokenizer_name}"); tracing::warn!("Could not find a fast tokenizer implementation for {tokenizer_name}");
tracing::warn!("Rust input length validation and truncation is disabled"); tracing::warn!("Rust input length validation and truncation is disabled");
@ -480,7 +483,7 @@ pub async fn get_model_info(api: &ApiRepo) -> Option<HubModelInfo> {
} }
/// get base tokenizer /// get base tokenizer
pub async fn get_base_tokenizer(api: &Api, api_repo: &ApiRepo) -> Option<Tokenizer> { pub async fn get_base_tokenizer(api: &Api, api_repo: &ApiRepo) -> Option<PathBuf> {
let config_filename = api_repo.get("config.json").await.ok()?; let config_filename = api_repo.get("config.json").await.ok()?;
// Open the file in read-only mode with buffer. // Open the file in read-only mode with buffer.
@ -497,8 +500,7 @@ pub async fn get_base_tokenizer(api: &Api, api_repo: &ApiRepo) -> Option<Tokeniz
"main".to_string(), "main".to_string(),
)); ));
let tokenizer_filename = api_base_repo.get("tokenizer.json").await.ok()?; api_base_repo.get("tokenizer.json").await.ok()
Tokenizer::from_file(tokenizer_filename).ok()
} else { } else {
None None
} }

View File

@ -1000,6 +1000,7 @@ async fn chat_completions(
tools, tools,
tool_choice, tool_choice,
tool_prompt, tool_prompt,
temperature,
.. ..
} = req; } = req;
@ -1008,6 +1009,11 @@ async fn chat_completions(
let logprobs = logprobs.unwrap_or(false); let logprobs = logprobs.unwrap_or(false);
let tool_prompt = tool_prompt.unwrap_or_default(); let tool_prompt = tool_prompt.unwrap_or_default();
let stop = stop.unwrap_or_default(); let stop = stop.unwrap_or_default();
// enable greedy only when temperature is 0
let (do_sample, temperature) = match temperature {
Some(temperature) if temperature == 0.0 => (false, None),
other => (true, other),
};
// extract tool grammar if present // extract tool grammar if present
let tool_grammar = match ToolGrammar::apply(tools, tool_choice) { let tool_grammar = match ToolGrammar::apply(tools, tool_choice) {
@ -1054,13 +1060,13 @@ async fn chat_completions(
inputs: inputs.to_string(), inputs: inputs.to_string(),
parameters: GenerateParameters { parameters: GenerateParameters {
best_of: None, best_of: None,
temperature: req.temperature, temperature,
repetition_penalty, repetition_penalty,
frequency_penalty: req.frequency_penalty, frequency_penalty: req.frequency_penalty,
top_k: None, top_k: None,
top_p: req.top_p, top_p: req.top_p,
typical_p: None, typical_p: None,
do_sample: true, do_sample,
max_new_tokens, max_new_tokens,
return_full_text: None, return_full_text: None,
stop, stop,
@ -1097,7 +1103,13 @@ async fn chat_completions(
let (content, tool_calls) = if tool_grammar.is_some() { let (content, tool_calls) = if tool_grammar.is_some() {
(None, Some(vec![stream_token.token.text])) (None, Some(vec![stream_token.token.text]))
} else { } else {
(Some(stream_token.token.text), None) let content = if !stream_token.token.special {
Some(stream_token.token.text)
} else {
None
};
(content, None)
}; };
event event

View File

@ -540,7 +540,57 @@ fn prepare_input(
inputs = modified_inputs; inputs = modified_inputs;
tokenizer_query tokenizer_query
} }
Some(Config::Idefics) => RE.replace_all(&inputs, "<image>").into(), Some(Config::Idefics2(config)) => {
let mut modified_inputs = String::with_capacity(inputs.len());
let mut tokenizer_query = String::with_capacity(inputs.len());
let mut start = 0;
for chunk in RE.find_iter(&inputs) {
let chunk_start = chunk.start();
let chunk_end = chunk.end();
if chunk_start != start {
modified_inputs.push_str(&inputs[start..chunk_start]);
tokenizer_query.push_str(&inputs[start..chunk_start]);
}
let (image_uri, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?;
let slots = config.get_number_of_features(height, width);
tokenizer_query.push_str("<fake_token_around_image>");
tokenizer_query.push_str(&"<image>".repeat(slots));
tokenizer_query.push_str("<fake_token_around_image>");
modified_inputs.push_str(&image_uri);
start = chunk_end;
}
if start != inputs.len() - 1 {
modified_inputs.push_str(&inputs[start..]);
tokenizer_query.push_str(&inputs[start..]);
}
inputs = modified_inputs;
tokenizer_query
}
Some(Config::Idefics) => {
let mut modified_inputs = String::with_capacity(inputs.len());
let mut tokenizer_query = String::with_capacity(inputs.len());
let mut start = 0;
for chunk in RE.find_iter(&inputs) {
let chunk_start = chunk.start();
let chunk_end = chunk.end();
if chunk_start != start {
modified_inputs.push_str(&inputs[start..chunk_start]);
tokenizer_query.push_str(&inputs[start..chunk_start]);
}
let (image_uri, _height, _width) = fetch_image(&inputs[chunk_start..chunk_end])?;
let slots = 1;
tokenizer_query.push_str(&"<image>".repeat(slots));
modified_inputs.push_str(&image_uri);
start = chunk_end;
}
if start != inputs.len() - 1 {
modified_inputs.push_str(&inputs[start..]);
tokenizer_query.push_str(&inputs[start..]);
}
inputs = modified_inputs;
tokenizer_query
}
_ => inputs.clone(), _ => inputs.clone(),
}; };

1081
server/poetry.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -1,6 +1,6 @@
[tool.poetry] [tool.poetry]
name = "text-generation-server" name = "text-generation-server"
version = "2.0.0" version = "2.0.1"
description = "Text Generation Inference Python gRPC Server" description = "Text Generation Inference Python gRPC Server"
authors = ["Olivier Dehaene <olivier@huggingface.co>"] authors = ["Olivier Dehaene <olivier@huggingface.co>"]
@ -24,9 +24,9 @@ opentelemetry-exporter-otlp = "^1.15.0"
opentelemetry-instrumentation-grpc = "^0.36b0" opentelemetry-instrumentation-grpc = "^0.36b0"
hf-transfer = "^0.1.2" hf-transfer = "^0.1.2"
sentencepiece = "^0.1.97" sentencepiece = "^0.1.97"
tokenizers = "^0.15.0" tokenizers = "^0.19.1"
huggingface-hub = "^0.19.3" huggingface-hub = "^0.19.3"
transformers = "^4.39" transformers = "^4.40"
einops = "^0.6.1" einops = "^0.6.1"
texttable = { version = "^1.6.7", optional = true } texttable = { version = "^1.6.7", optional = true }
datasets = { version = "^2.14.0", optional = true } datasets = { version = "^2.14.0", optional = true }

View File

@ -5,7 +5,7 @@ click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows") colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13" deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13" einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
filelock==3.13.3 ; python_version >= "3.9" and python_version < "3.13" filelock==3.13.4 ; python_version >= "3.9" and python_version < "3.13"
fsspec==2024.2.0 ; python_version >= "3.9" and python_version < "3.13" fsspec==2024.2.0 ; python_version >= "3.9" and python_version < "3.13"
googleapis-common-protos==1.63.0 ; python_version >= "3.9" and python_version < "3.13" googleapis-common-protos==1.63.0 ; python_version >= "3.9" and python_version < "3.13"
grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13" grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
@ -14,7 +14,7 @@ grpcio-status==1.62.1 ; python_version >= "3.9" and python_version < "3.13"
grpcio==1.62.1 ; python_version >= "3.9" and python_version < "3.13" grpcio==1.62.1 ; python_version >= "3.9" and python_version < "3.13"
hf-transfer==0.1.6 ; python_version >= "3.9" and python_version < "3.13" hf-transfer==0.1.6 ; python_version >= "3.9" and python_version < "3.13"
huggingface-hub==0.19.4 ; python_version >= "3.9" and python_version < "3.13" huggingface-hub==0.19.4 ; python_version >= "3.9" and python_version < "3.13"
idna==3.6 ; python_version >= "3.9" and python_version < "3.13" idna==3.7 ; python_version >= "3.9" and python_version < "3.13"
loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13" loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13" numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "3.13" opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
@ -30,15 +30,15 @@ packaging==24.0 ; python_version >= "3.9" and python_version < "3.13"
pillow==10.3.0 ; python_version >= "3.9" and python_version < "3.13" pillow==10.3.0 ; python_version >= "3.9" and python_version < "3.13"
protobuf==4.25.3 ; python_version >= "3.9" and python_version < "3.13" protobuf==4.25.3 ; python_version >= "3.9" and python_version < "3.13"
pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13" pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
regex==2023.12.25 ; python_version >= "3.9" and python_version < "3.13" regex==2024.4.16 ; python_version >= "3.9" and python_version < "3.13"
requests==2.31.0 ; python_version >= "3.9" and python_version < "3.13" requests==2.31.0 ; python_version >= "3.9" and python_version < "3.13"
safetensors==0.4.2 ; python_version >= "3.9" and python_version < "3.13" safetensors==0.4.3 ; python_version >= "3.9" and python_version < "3.13"
scipy==1.13.0 ; python_version >= "3.9" and python_version < "3.13" scipy==1.13.0 ; python_version >= "3.9" and python_version < "3.13"
sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13" sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
setuptools==69.2.0 ; python_version >= "3.9" and python_version < "3.13" setuptools==69.5.1 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.15.2 ; python_version >= "3.9" and python_version < "3.13" tokenizers==0.19.1 ; python_version >= "3.9" and python_version < "3.13"
tqdm==4.66.2 ; python_version >= "3.9" and python_version < "3.13" tqdm==4.66.2 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.39.3 ; python_version >= "3.9" and python_version < "3.13" transformers==4.40.0 ; python_version >= "3.9" and python_version < "3.13"
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13" typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
typing-extensions==4.11.0 ; python_version >= "3.9" and python_version < "3.13" typing-extensions==4.11.0 ; python_version >= "3.9" and python_version < "3.13"
urllib3==2.2.1 ; python_version >= "3.9" and python_version < "3.13" urllib3==2.2.1 ; python_version >= "3.9" and python_version < "3.13"

View File

@ -5,7 +5,7 @@ click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows") colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13" deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13" einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
filelock==3.13.3 ; python_version >= "3.9" and python_version < "3.13" filelock==3.13.4 ; python_version >= "3.9" and python_version < "3.13"
fsspec==2024.2.0 ; python_version >= "3.9" and python_version < "3.13" fsspec==2024.2.0 ; python_version >= "3.9" and python_version < "3.13"
googleapis-common-protos==1.63.0 ; python_version >= "3.9" and python_version < "3.13" googleapis-common-protos==1.63.0 ; python_version >= "3.9" and python_version < "3.13"
grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13" grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
@ -14,7 +14,7 @@ grpcio-status==1.62.1 ; python_version >= "3.9" and python_version < "3.13"
grpcio==1.62.1 ; python_version >= "3.9" and python_version < "3.13" grpcio==1.62.1 ; python_version >= "3.9" and python_version < "3.13"
hf-transfer==0.1.6 ; python_version >= "3.9" and python_version < "3.13" hf-transfer==0.1.6 ; python_version >= "3.9" and python_version < "3.13"
huggingface-hub==0.19.4 ; python_version >= "3.9" and python_version < "3.13" huggingface-hub==0.19.4 ; python_version >= "3.9" and python_version < "3.13"
idna==3.6 ; python_version >= "3.9" and python_version < "3.13" idna==3.7 ; python_version >= "3.9" and python_version < "3.13"
loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13" loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13" numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "3.13" opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
@ -30,15 +30,15 @@ packaging==24.0 ; python_version >= "3.9" and python_version < "3.13"
pillow==10.3.0 ; python_version >= "3.9" and python_version < "3.13" pillow==10.3.0 ; python_version >= "3.9" and python_version < "3.13"
protobuf==4.25.3 ; python_version >= "3.9" and python_version < "3.13" protobuf==4.25.3 ; python_version >= "3.9" and python_version < "3.13"
pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13" pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
regex==2023.12.25 ; python_version >= "3.9" and python_version < "3.13" regex==2024.4.16 ; python_version >= "3.9" and python_version < "3.13"
requests==2.31.0 ; python_version >= "3.9" and python_version < "3.13" requests==2.31.0 ; python_version >= "3.9" and python_version < "3.13"
safetensors==0.4.2 ; python_version >= "3.9" and python_version < "3.13" safetensors==0.4.3 ; python_version >= "3.9" and python_version < "3.13"
scipy==1.13.0 ; python_version >= "3.9" and python_version < "3.13" scipy==1.13.0 ; python_version >= "3.9" and python_version < "3.13"
sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13" sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
setuptools==69.2.0 ; python_version >= "3.9" and python_version < "3.13" setuptools==69.5.1 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.15.2 ; python_version >= "3.9" and python_version < "3.13" tokenizers==0.19.1 ; python_version >= "3.9" and python_version < "3.13"
tqdm==4.66.2 ; python_version >= "3.9" and python_version < "3.13" tqdm==4.66.2 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.39.3 ; python_version >= "3.9" and python_version < "3.13" transformers==4.40.0 ; python_version >= "3.9" and python_version < "3.13"
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13" typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
typing-extensions==4.11.0 ; python_version >= "3.9" and python_version < "3.13" typing-extensions==4.11.0 ; python_version >= "3.9" and python_version < "3.13"
urllib3==2.2.1 ; python_version >= "3.9" and python_version < "3.13" urllib3==2.2.1 ; python_version >= "3.9" and python_version < "3.13"

View File

@ -68,6 +68,7 @@ try:
) )
from text_generation_server.models.idefics import IDEFICSSharded from text_generation_server.models.idefics import IDEFICSSharded
from text_generation_server.models.llava_next import LlavaNext from text_generation_server.models.llava_next import LlavaNext
from text_generation_server.models.idefics2 import Idefics2
from text_generation_server.models.flash_mistral import FlashMistral from text_generation_server.models.flash_mistral import FlashMistral
from text_generation_server.models.flash_mixtral import FlashMixtral from text_generation_server.models.flash_mixtral import FlashMixtral
from text_generation_server.models.flash_phi import FlashPhi from text_generation_server.models.flash_phi import FlashPhi
@ -327,7 +328,7 @@ def get_model(
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
elif model_type == "llama" or model_type == "baichuan": elif model_type == "llama" or model_type == "baichuan" or model_type == "phi3":
if FLASH_ATTENTION: if FLASH_ATTENTION:
return FlashLlama( return FlashLlama(
model_id, model_id,
@ -579,6 +580,18 @@ def get_model(
) )
else: else:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics")) raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
if model_type == "idefics2":
if FLASH_ATTENTION:
return Idefics2(
model_id,
revision,
quantize=quantize,
use_medusa=use_medusa,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
else:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
if model_type == "llava_next": if model_type == "llava_next":
if FLASH_ATTENTION: if FLASH_ATTENTION:

View File

@ -45,58 +45,6 @@ if IS_ROCM_SYSTEM:
raise ImportError(f"Could not load `vllm._custom_C`. Full error: {e}") raise ImportError(f"Could not load `vllm._custom_C`. Full error: {e}")
class LlamaConfig(PretrainedConfig):
def __init__(
self,
vocab_size=32000,
hidden_size=4096,
intermediate_size=11008,
num_hidden_layers=32,
num_attention_heads=32,
num_key_value_heads=None,
hidden_act="silu",
max_position_embeddings=2048,
initializer_range=0.02,
rms_norm_eps=1e-6,
use_cache=True,
pad_token_id=0,
bos_token_id=1,
eos_token_id=2,
pretraining_tp=1,
tie_word_embeddings=False,
rope_scaling=None,
rope_theta=10000.0,
**kwargs,
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
# for backward compatibility
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.pretraining_tp = pretraining_tp
self.use_cache = use_cache
self.rope_scaling = rope_scaling
self.rope_theta = rope_theta
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
def load_attention(config, prefix, weights): def load_attention(config, prefix, weights):
if config.num_attention_heads != config.num_key_value_heads: if config.num_attention_heads != config.num_key_value_heads:
return _load_gqa(config, prefix, weights) return _load_gqa(config, prefix, weights)
@ -108,6 +56,13 @@ def load_attention(config, prefix, weights):
weights=weights, weights=weights,
bias=False, bias=False,
) )
elif config.model_type == "phi3":
return TensorParallelColumnLinear.load_qkv(
config,
prefix=f"{prefix}.qkv_proj",
weights=weights,
bias=False,
)
else: else:
return TensorParallelColumnLinear.load_multi( return TensorParallelColumnLinear.load_multi(
config, config,
@ -265,13 +220,21 @@ class LlamaMLP(nn.Module):
) )
) )
# Fuse gate and up proj # Fuse gate and up proj
self.gate_up_proj = TensorParallelColumnLinear.load_multi( if config.model_type == "phi3":
config, self.gate_up_proj = TensorParallelColumnLinear.load_gate_up(
prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"], config,
weights=weights, prefix=f"{prefix}.gate_up_proj",
dim=0, weights=weights,
bias=False, bias=False,
) )
else:
self.gate_up_proj = TensorParallelColumnLinear.load_multi(
config,
prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
weights=weights,
dim=0,
bias=False,
)
self.down_proj = TensorParallelRowLinear.load( self.down_proj = TensorParallelRowLinear.load(
config, config,
prefix=f"{prefix}.down_proj", prefix=f"{prefix}.down_proj",

View File

@ -409,23 +409,29 @@ class MistralModel(torch.nn.Module):
class FlashMistralForCausalLM(torch.nn.Module): class FlashMistralForCausalLM(torch.nn.Module):
def __init__(self, prefix, config, weights): def __init__(self, prefix, config, weights, name=None):
if name is None:
name = "model"
super().__init__() super().__init__()
self.embed_tokens = TensorParallelEmbedding( self.embed_tokens = TensorParallelEmbedding(
prefix=( prefix=(
"model.embed_tokens" if not prefix else f"{prefix}.model.embed_tokens" f"{name}.embed_tokens"
if not prefix
else f"{prefix}.{name}.embed_tokens"
), ),
weights=weights, weights=weights,
) )
self.model = MistralModel( self.model = MistralModel(
prefix="model" if not prefix else f"{prefix}.model", prefix=name if not prefix else f"{prefix}.{name}",
config=config, config=config,
weights=weights, weights=weights,
) )
self.lm_head = SpeculativeHead.load( self.lm_head = SpeculativeHead.load(
config, config,
prefix="lm_head" if not prefix else f"{prefix}.lm_head", # TODO dirty hack for idefics2.
prefix=(
"lm_head" if not prefix or name != "model" else f"{prefix}.lm_head"
),
weights=weights, weights=weights,
) )
self.max_past = config.sliding_window self.max_past = config.sliding_window

View File

@ -0,0 +1,829 @@
# coding=utf-8
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch Idefics2 model."""
from typing import List, Optional, Tuple, Union
import torch
import torch.utils.checkpoint
from torch import nn
import math
from transformers.activations import ACT2FN
from transformers.image_processing_utils import select_best_resolution
from text_generation_server.models.custom_modeling.vlm import (
load_text_model,
load_vision_model,
)
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
from text_generation_server.utils.layers import (
TensorParallelColumnLinear,
TensorParallelEmbedding,
TensorParallelRowLinear,
)
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
"""
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(
batch, num_key_value_heads, n_rep, slen, head_dim
)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
class Idefics2VisionEmbeddings(nn.Module):
"""
This is a modified version of `siglip.modelign_siglip.SiglipVisionEmbeddings` to enable images of variable
resolution.
The modifications are adapted from [Patch n' Pack: NaViT, a Vision Transformer for any Aspect Ratio and Resolution](https://arxiv.org/abs/2307.06304)
which allows treating images in their native aspect ratio and without the need to resize them to the same
fixed size. In particular, we start from the original pre-trained SigLIP model
(which uses images of fixed-size square images) and adapt it by training on images of variable resolutions.
"""
def __init__(self, prefix, config, weights):
super().__init__()
self.embed_dim = config.hidden_size
self.image_size = config.image_size
self.patch_size = config.patch_size
self.patch_embedding = nn.Conv2d(
in_channels=config.num_channels,
out_channels=self.embed_dim,
kernel_size=self.patch_size,
stride=self.patch_size,
padding="valid",
)
self.patch_embedding.weight = nn.Parameter(
weights.get_tensor(f"{prefix}.patch_embedding.weight"), requires_grad=False
)
self.patch_embedding.bias = nn.Parameter(
weights.get_tensor(f"{prefix}.patch_embedding.bias"), requires_grad=False
)
self.num_patches_per_side = self.image_size // self.patch_size
self.num_patches = self.num_patches_per_side**2
self.num_positions = self.num_patches
self.position_embedding = TensorParallelEmbedding(
prefix=f"{prefix}.position_embedding", weights=weights
)
def forward(
self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.BoolTensor
) -> torch.Tensor:
batch_size, _, max_im_h, max_im_w = pixel_values.shape
patch_embeds = self.patch_embedding(pixel_values)
embeddings = patch_embeds.flatten(2).transpose(1, 2)
max_nb_patches_h, max_nb_patches_w = (
max_im_h // self.patch_size,
max_im_w // self.patch_size,
)
boundaries = torch.arange(
1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side
)
position_ids = torch.full(
size=(batch_size, max_nb_patches_h * max_nb_patches_w), fill_value=0
)
for batch_idx, p_attn_mask in enumerate(patch_attention_mask):
nb_patches_h = p_attn_mask[:, 0].sum()
nb_patches_w = p_attn_mask[0].sum()
fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h)
fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w)
bucket_coords_h = torch.bucketize(
fractional_coords_h, boundaries, right=True
)
bucket_coords_w = torch.bucketize(
fractional_coords_w, boundaries, right=True
)
pos_ids = (
bucket_coords_h[:, None] * self.num_patches_per_side + bucket_coords_w
).flatten()
position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids
position_ids = position_ids.to(self.position_embedding.weight.device)
embeddings = embeddings + self.position_embedding(position_ids)
return embeddings
class Idefics2VisionAttention(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_size = self.embed_dim // self.num_heads
if self.head_size * self.num_heads != self.embed_dim:
raise ValueError(
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
f" {self.num_heads})."
)
self.scale = self.head_size**-0.5
self.dropout = config.attention_dropout
self.num_heads = self.num_heads // weights.process_group.size()
self.embed_dim = self.embed_dim // weights.process_group.size()
self.qkv = TensorParallelColumnLinear.load_multi(
config,
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
dim=0,
weights=weights,
bias=True,
)
self.out_proj = TensorParallelRowLinear.load(
config=config, prefix=f"{prefix}.out_proj", weights=weights, bias=True
)
self.is_causal = False
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
batch_size, q_len, _ = hidden_states.size()
qkv = self.qkv(hidden_states)
query_states, key_states, value_states = qkv.split(
[
self.head_size * self.num_heads,
self.head_size * self.num_heads,
self.head_size * self.num_heads,
],
dim=2,
)
query_states = query_states.view(
batch_size, q_len, self.num_heads, self.head_size
).transpose(1, 2)
key_states = key_states.view(
batch_size, q_len, self.num_heads, self.head_size
).transpose(1, 2)
value_states = value_states.view(
batch_size, q_len, self.num_heads, self.head_size
).transpose(1, 2)
k_v_seq_len = key_states.shape[-2]
attn_weights = (
torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale
)
if attn_weights.size() != (batch_size, self.num_heads, q_len, k_v_seq_len):
raise ValueError(
f"Attention weights should be of size {(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is"
f" {attn_weights.size()}"
)
if attention_mask is not None:
if attention_mask.size() != (batch_size, 1, q_len, k_v_seq_len):
raise ValueError(
f"Attention mask should be of size {(batch_size, 1, q_len, k_v_seq_len)}, but is {attention_mask.size()}"
)
attn_weights = attn_weights + attention_mask
# upcast attention to fp32
attn_weights = nn.functional.softmax(
attn_weights, dim=-1, dtype=torch.float32
).to(query_states.dtype)
attn_weights = nn.functional.dropout(
attn_weights, p=self.dropout, training=self.training
)
attn_output = torch.matmul(attn_weights, value_states)
if attn_output.size() != (batch_size, self.num_heads, q_len, self.head_size):
raise ValueError(
f"`attn_output` should be of size {(batch_size, self.num_heads, q_len, self.head_size)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim)
attn_output = self.out_proj(attn_output)
return attn_output
class Idefics2VisionMLP(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
self.config = config
self.activation_fn = ACT2FN[config.hidden_act]
self.fc1 = TensorParallelColumnLinear.load(
prefix=f"{prefix}.fc1", config=config, weights=weights, bias=True
)
self.fc2 = TensorParallelRowLinear.load(
prefix=f"{prefix}.fc2", config=config, weights=weights, bias=True
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.fc1(hidden_states)
hidden_states = self.activation_fn(hidden_states)
hidden_states = self.fc2(hidden_states)
return hidden_states
class Idefics2EncoderLayer(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
self.embed_dim = config.hidden_size
self.self_attn = Idefics2VisionAttention(
prefix=f"{prefix}.self_attn", config=config, weights=weights
)
self.layer_norm1 = nn.LayerNorm.load(
prefix=f"{prefix}.layer_norm1", eps=config.layer_norm_eps, weights=weights
)
self.layer_norm2 = nn.LayerNorm.load(
prefix=f"{prefix}.layer_norm2", eps=config.layer_norm_eps, weights=weights
)
self.mlp = Idefics2VisionMLP(
prefix=f"{prefix}.mlp", config=config, weights=weights
)
# Copied from transformers.models.siglip.modeling_siglip.SiglipEncoderLayer.forward
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
) -> torch.Tensor:
residual = hidden_states
hidden_states = self.layer_norm1(hidden_states)
hidden_states = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.layer_norm2(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
return hidden_states
class Idefics2Encoder(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
self.config = config
self.layers = nn.ModuleList(
[
Idefics2EncoderLayer(
prefix=f"{prefix}.layers.{i}", config=config, weights=weights
)
for i in range(config.num_hidden_layers)
]
)
# Ignore copy
def forward(
self,
inputs_embeds,
attention_mask: Optional[torch.Tensor] = None,
):
hidden_states = inputs_embeds
for encoder_layer in self.layers:
hidden_states = encoder_layer(
hidden_states,
attention_mask,
)
return hidden_states
class Idefics2VisionTransformer(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
self.config = config
self.embeddings = Idefics2VisionEmbeddings(
prefix=f"{prefix}.embeddings", config=config, weights=weights
)
self.encoder = Idefics2Encoder(
prefix=f"{prefix}.encoder", config=config, weights=weights
)
self.post_layernorm = nn.LayerNorm.load(
prefix=f"{prefix}.post_layernorm",
weights=weights,
eps=config.layer_norm_eps,
)
def forward(
self,
pixel_values,
patch_attention_mask: Optional[torch.BoolTensor] = None,
):
batch_size = pixel_values.size(0)
if patch_attention_mask is None:
patch_size = self.config.patch_size
patch_attention_mask = torch.ones(
(
batch_size,
pixel_values.size(2) // patch_size,
pixel_values.size(3) // patch_size,
)
)
patch_attention_mask = patch_attention_mask.to(
dtype=torch.bool, device=pixel_values.device
)
hidden_states = self.embeddings(
pixel_values=pixel_values, patch_attention_mask=patch_attention_mask
)
patch_attention_mask = patch_attention_mask.view(batch_size, -1)
# The call to `_upad_input` in `_flash_attention_forward` is expensive
# So when the `patch_attention_mask` is full of 1s (i.e. attending to the whole sequence),
# avoiding passing the attention_mask, which is equivalent to attending to the full sequence
if not torch.any(~patch_attention_mask):
patch_attention_mask = None
else:
patch_attention_mask = _prepare_4d_attention_mask(
patch_attention_mask, hidden_states.dtype
)
encoder_outputs = self.encoder(
inputs_embeds=hidden_states,
attention_mask=patch_attention_mask,
)
last_hidden_state = encoder_outputs
last_hidden_state = self.post_layernorm(last_hidden_state)
return last_hidden_state
class Idefics2MLP(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
act = config.text_config.hidden_act
self.act = (
ACT2FN[act]
if "gelu" not in act
else lambda x: torch.nn.functional.gelu(
x,
approximate=(
"tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
),
)
)
self.gate_up_proj = TensorParallelColumnLinear.load_multi(
config,
prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
weights=weights,
dim=0,
bias=False,
)
self.down_proj = TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.down_proj",
weights=weights,
bias=False,
)
def forward(self, hidden_states):
start_shape = hidden_states.shape[:-1]
gate_up_states = self.gate_up_proj(hidden_states)
intermediate_size = gate_up_states.shape[-1] // 2
gate_up_states = gate_up_states.view(-1, 2, intermediate_size)
return self.down_proj(
self.act(gate_up_states[:, 0]) * gate_up_states[:, 1]
).view(*start_shape, -1)
class Idefics2RMSNorm(nn.Module):
def __init__(self, prefix, weights, eps):
"""
Idefics2RMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.weight = nn.Parameter(
weights.get_tensor(f"{prefix}.weight"), requires_grad=False
)
self.variance_epsilon = eps
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)
class Idefics2PerceiverAttention(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
self.layer_idx = None
self.hidden_size = config.text_config.hidden_size
self.num_heads = config.perceiver_config.resampler_n_heads
self.head_size = config.perceiver_config.resampler_head_dim
self.num_key_value_heads = config.perceiver_config.num_key_value_heads
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
self.attention_dropout = config.perceiver_config.attention_dropout
self.num_heads = self.num_heads // weights.process_group.size()
self.num_key_value_heads = (
self.num_key_value_heads // weights.process_group.size()
)
self.q_proj = TensorParallelColumnLinear.load(
config,
prefix=f"{prefix}.q_proj",
weights=weights,
bias=False,
)
self.kv = TensorParallelColumnLinear.load_multi(
config,
prefixes=[f"{prefix}.k_proj", f"{prefix}.v_proj"],
dim=0,
weights=weights,
bias=False,
)
self.o_proj = TensorParallelRowLinear.load(
config=config, prefix=f"{prefix}.o_proj", weights=weights, bias=False
)
self.is_causal = False
def forward(
self,
latents: torch.Tensor,
context: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = latents.size()
kv_seq_len = q_len + context.size()[1]
hidden_states = torch.concat([context, latents], dim=-2)
query_states = self.q_proj(latents)
kv = self.kv(hidden_states)
key_states, value_states = kv.split(
[
self.head_size * self.num_key_value_heads,
self.head_size * self.num_key_value_heads,
],
dim=2,
)
query_states = query_states.view(
bsz, q_len, self.num_heads, self.head_size
).transpose(1, 2)
key_states = key_states.view(
bsz, kv_seq_len, self.num_key_value_heads, self.head_size
).transpose(1, 2)
value_states = value_states.view(
bsz, kv_seq_len, self.num_key_value_heads, self.head_size
).transpose(1, 2)
# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
attn_weights = torch.matmul(
query_states, key_states.transpose(2, 3)
) / math.sqrt(self.head_size)
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
raise ValueError(
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
f" {attn_weights.size()}"
)
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
)
attn_weights = attn_weights + attention_mask
# upcast attention to fp32
attn_weights = nn.functional.softmax(
attn_weights, dim=-1, dtype=torch.float32
).to(query_states.dtype)
attn_output = torch.matmul(attn_weights, value_states)
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_size):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_size)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_size)
attn_output = self.o_proj(attn_output)
return attn_output
class Idefics2PerceiverLayer(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
self.hidden_size = config.text_config.hidden_size
self.n_latents = config.perceiver_config.resampler_n_latents
self.depth = config.perceiver_config.resampler_depth
self.rms_norm_eps = config.text_config.rms_norm_eps
self.input_latents_norm = Idefics2RMSNorm(
prefix=f"{prefix}.input_latents_norm",
weights=weights,
eps=self.rms_norm_eps,
)
self.input_context_norm = Idefics2RMSNorm(
prefix=f"{prefix}.input_context_norm",
weights=weights,
eps=self.rms_norm_eps,
)
self.self_attn = Idefics2PerceiverAttention(
prefix=f"{prefix}.self_attn", config=config, weights=weights
)
self.post_attention_layernorm = Idefics2RMSNorm(
prefix=f"{prefix}.post_attention_layernorm",
weights=weights,
eps=self.rms_norm_eps,
)
self.mlp = Idefics2MLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
def forward(
self,
latents: torch.Tensor,
context: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
):
"""
Args:
latents (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
context (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
`(batch, sequence_length)` where padding elements are indicated by 0.
"""
residual = latents
latents = self.input_latents_norm(latents)
context = self.input_context_norm(context)
latents = self.self_attn(
latents=latents,
context=context,
attention_mask=attention_mask,
)
latents = residual + latents
residual = latents
latents = self.post_attention_layernorm(latents)
latents = self.mlp(latents)
latents = residual + latents
return latents
class Idefics2PerceiverResampler(nn.Module):
def __init__(self, prefix, config, weights) -> None:
super().__init__()
self.hidden_size = config.text_config.hidden_size
self.hidden_act = config.perceiver_config.hidden_act
self.n_latents = config.perceiver_config.resampler_n_latents
self.depth = config.perceiver_config.resampler_depth
self.rms_norm_eps = config.text_config.rms_norm_eps
# Create Latents for Perceiver
self.latents = weights.get_tensor(f"{prefix}.latents")
# Create Transformer Blocks
self.layers = nn.ModuleList(
[
Idefics2PerceiverLayer(
prefix=f"{prefix}.layers.{idx}", config=config, weights=weights
)
for idx in range(self.depth)
]
)
self.norm = Idefics2RMSNorm(
prefix=f"{prefix}.norm",
weights=weights,
eps=config.text_config.rms_norm_eps,
)
def forward(
self,
context: torch.Tensor,
attention_mask,
) -> torch.Tensor:
# seq embed -> bsz seq embed
latents = self.latents.unsqueeze(0).expand(
(context.shape[0], *self.latents.size())
)
latent_attention_mask = torch.ones(
(attention_mask.size(0), latents.size(1)),
dtype=attention_mask.dtype,
device=attention_mask.device,
)
attention_mask = torch.cat([attention_mask, latent_attention_mask], dim=-1)
attention_mask = _prepare_4d_attention_mask(
attention_mask, latents.dtype, tgt_len=self.n_latents
)
compressed_context = latents
for perceiver_layer in self.layers:
compressed_context = perceiver_layer(
compressed_context,
context,
attention_mask=attention_mask,
)
compressed_context = self.norm(compressed_context)
return compressed_context
class Idefics2Connector(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
self.modality_projection = Idefics2MLP(
prefix=f"{prefix}.modality_projection", config=config, weights=weights
)
self.perceiver_resampler = Idefics2PerceiverResampler(
prefix=f"{prefix}.perceiver_resampler", config=config, weights=weights
)
def forward(self, image_hidden_states, attention_mask):
image_hidden_states = self.modality_projection(image_hidden_states)
image_hidden_states = self.perceiver_resampler(
context=image_hidden_states, attention_mask=attention_mask
)
return image_hidden_states
class Idefics2ForConditionalGeneration(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
config.vision_config.quantize = config.quantize
config.vision_config.use_medusa = config.use_medusa
config.text_config.quantize = config.quantize
config.text_config.use_medusa = config.use_medusa
vision_config = config.vision_config
self.text_model = load_text_model(
prefix="model" if not prefix else f"{prefix}.model",
config=config.text_config,
weights=weights,
name="text_model",
)
self.dtype = weights.dtype
self.vision_model = Idefics2VisionTransformer(
prefix=f"{prefix}.model.vision_model" if prefix else "model.vision_model",
config=vision_config,
weights=weights,
)
self.connector = Idefics2Connector(
prefix=f"{prefix}.model.connector" if prefix else "model.connector",
config=config,
weights=weights,
)
self.config = config
self.image_seq_len = config.perceiver_config.resampler_n_latents
self.image_token_id = config.image_token_id
self.pad_token_id = (
config.pad_token_id if config.pad_token_id is not None else -1
)
def _merge_input_ids_with_image_features(
self,
input_ids: torch.Tensor,
inputs_embeds: torch.Tensor,
image_features: torch.Tensor,
):
"""In place merges in vision_embeddings with inputs_embeds."""
# mask = input_ids == self.config.image_token_index
mask = input_ids == self.config.image_token_id
# Let's pray we have enabled enough slots !
inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
return inputs_embeds
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
pixel_values: torch.FloatTensor = None,
pixel_attention_mask: Optional[torch.BoolTensor] = None,
# Unused here
image_sizes: Optional[torch.Tensor] = None,
):
inputs_embeds = self.text_model.embed_tokens(input_ids)
if pixel_values is not None:
batch_size, num_images, num_channels, height, width = pixel_values.shape
all_states = []
all_pixel_values = pixel_values
all_pixel_mask = pixel_attention_mask
for i in range(batch_size):
pixel_values = all_pixel_values.to(
dtype=self.dtype
) # fp16 compatibility
pixel_values = pixel_values[i : i + 1]
pixel_values = pixel_values.view(num_images, *pixel_values.shape[2:])
# Remove padding images - padding images are full 0.
nb_values_per_image = pixel_values.shape[1:].numel()
real_images_inds = (pixel_values == 0.0).sum(
dim=(-1, -2, -3)
) != nb_values_per_image
pixel_values = pixel_values[real_images_inds].contiguous()
# Handle the vision attention mask
if pixel_attention_mask is None:
pixel_attention_mask = torch.ones(
size=(
pixel_values.size(0),
pixel_values.size(2),
pixel_values.size(3),
),
dtype=torch.bool,
device=pixel_values.device,
)
else:
# Remove padding images from the mask/pP p
pixel_attention_mask = all_pixel_mask[i : i + 1]
pixel_attention_mask = pixel_attention_mask.view(
1 * num_images, *pixel_attention_mask.shape[2:]
)
pixel_attention_mask = pixel_attention_mask[
real_images_inds
].contiguous()
patch_size = self.config.vision_config.patch_size
patches_subgrid = pixel_attention_mask.unfold(
dimension=1, size=patch_size, step=patch_size
)
patches_subgrid = patches_subgrid.unfold(
dimension=2, size=patch_size, step=patch_size
)
patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
# Get sequence from the vision encoder
image_hidden_states = self.vision_model(
pixel_values=pixel_values,
patch_attention_mask=patch_attention_mask,
)
# Modality projection & resampling
image_hidden_states = self.connector(
image_hidden_states,
attention_mask=patch_attention_mask.view(pixel_values.size(0), -1),
)
all_states.append(image_hidden_states)
image_hidden_states = torch.stack(all_states, dim=0)
# When we generate, we don't want to replace the potential image_token_id that we generated by images
# that simply don't exist
inputs_embeds = self._merge_input_ids_with_image_features(
input_ids, inputs_embeds, image_hidden_states
)
hidden_states = self.text_model.model(
inputs_embeds=inputs_embeds,
position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill,
kv_cache=kv_cache,
block_tables=block_tables,
slots=slots,
input_lengths=input_lengths,
max_s=max_s,
true_max_s=max_s,
prefill_cache_indices=None,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]
logits, speculative_logits = self.text_model.lm_head(hidden_states)
return logits, speculative_logits

View File

@ -23,6 +23,10 @@ from torch import nn
from transformers.activations import ACT2FN from transformers.activations import ACT2FN
from transformers.image_processing_utils import select_best_resolution from transformers.image_processing_utils import select_best_resolution
from text_generation_server.models.custom_modeling.vlm import (
load_text_model,
load_vision_model,
)
from text_generation_server.utils.layers import ( from text_generation_server.utils.layers import (
TensorParallelColumnLinear, TensorParallelColumnLinear,
TensorParallelRowLinear, TensorParallelRowLinear,
@ -105,36 +109,6 @@ class LlavaNextMultiModalProjector(nn.Module):
return hidden_states return hidden_states
def load_vision_model(prefix, config, weights):
if config.model_type == "clip_vision_model":
from text_generation_server.models.custom_modeling.clip import (
CLIPVisionTransformer,
)
return CLIPVisionTransformer(
prefix=f"{prefix}.vision_model", config=config, weights=weights
)
else:
raise RuntimeError(f"Unsupported model type {config.model_type}")
def load_text_model(prefix, config, weights):
if config.model_type == "llama":
from text_generation_server.models.custom_modeling.flash_llama_modeling import (
FlashLlamaForCausalLM,
)
return FlashLlamaForCausalLM(prefix, config, weights)
elif config.model_type == "mistral":
from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
FlashMistralForCausalLM,
)
return FlashMistralForCausalLM(prefix, config, weights)
else:
raise RuntimeError(f"Unsupported model type {config.model_type}")
class LlavaNextForConditionalGeneration(nn.Module): class LlavaNextForConditionalGeneration(nn.Module):
def __init__(self, prefix, config, weights): def __init__(self, prefix, config, weights):
super().__init__() super().__init__()
@ -180,7 +154,12 @@ class LlavaNextForConditionalGeneration(nn.Module):
"""In place merges in vision_embeddings with inputs_embeds.""" """In place merges in vision_embeddings with inputs_embeds."""
mask = input_ids == self.config.image_token_index mask = input_ids == self.config.image_token_index
# Let's pray we have enabled enough slots ! # Let's pray we have enabled enough slots !
inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1]) try:
inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
except Exception as e:
raise RuntimeError(
f"Cannot fill images right now. If error happens at warmup, make sure you have enough `--max-input-tokens` to handle images. If error happens at regular runtime, please fill in an issue: {e}"
)
return inputs_embeds return inputs_embeds
def forward( def forward(
@ -196,6 +175,8 @@ class LlavaNextForConditionalGeneration(nn.Module):
prefill_cache_indices: Optional[torch.Tensor], prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None,
pixel_values: torch.FloatTensor = None, pixel_values: torch.FloatTensor = None,
# Unused for this model
pixel_attention_mask=None,
image_sizes: Optional[torch.LongTensor] = None, image_sizes: Optional[torch.LongTensor] = None,
): ):
inputs_embeds = self.language_model.embed_tokens(input_ids) inputs_embeds = self.language_model.embed_tokens(input_ids)

View File

@ -0,0 +1,28 @@
def load_text_model(prefix, config, weights, name=None):
if config.model_type == "llama":
from text_generation_server.models.custom_modeling.flash_llama_modeling import (
FlashLlamaForCausalLM,
)
return FlashLlamaForCausalLM(prefix, config, weights)
elif config.model_type == "mistral":
from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
FlashMistralForCausalLM,
)
return FlashMistralForCausalLM(prefix, config, weights, name=name)
else:
raise RuntimeError(f"Unsupported model type {config.model_type}")
def load_vision_model(prefix, config, weights):
if config.model_type == "clip_vision_model":
from text_generation_server.models.custom_modeling.clip import (
CLIPVisionTransformer,
)
return CLIPVisionTransformer(
prefix=f"{prefix}.vision_model", config=config, weights=weights
)
else:
raise RuntimeError(f"Unsupported model type {config.model_type}")

View File

@ -2,14 +2,13 @@ import torch
import torch.distributed import torch.distributed
from opentelemetry import trace from opentelemetry import trace
from transformers import AutoConfig, AutoTokenizer from transformers import AutoConfig, AutoTokenizer, GenerationConfig
from transformers.models.llama import LlamaTokenizer from transformers.models.llama import LlamaTokenizer
from typing import Optional from typing import Optional
from text_generation_server.models import FlashCausalLM from text_generation_server.models import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_llama_modeling import ( from text_generation_server.models.custom_modeling.flash_llama_modeling import (
FlashLlamaForCausalLM, FlashLlamaForCausalLM,
LlamaConfig,
) )
from text_generation_server.utils import ( from text_generation_server.utils import (
initialize_torch_distributed, initialize_torch_distributed,
@ -53,8 +52,17 @@ class FlashLlama(FlashCausalLM):
truncation_side="left", truncation_side="left",
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
try:
generation_config = GenerationConfig.from_pretrained(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
if isinstance(generation_config.eos_token_id, (list, set)):
# TODO Huge hack
tokenizer._eos_token_ids = set(generation_config.eos_token_id)
except Exception:
pass
config = LlamaConfig.from_pretrained( config = AutoConfig.from_pretrained(
model_id, revision=revision, trust_remote_code=trust_remote_code model_id, revision=revision, trust_remote_code=trust_remote_code
) )
config.quantize = quantize config.quantize = quantize

View File

@ -511,18 +511,33 @@ class BaseFlashMistral(FlashCausalLM):
cuda_graph = self.cuda_graphs.get(padded_bs, None) cuda_graph = self.cuda_graphs.get(padded_bs, None)
if cu_seqlen_prefill is not None or cuda_graph is None: if cu_seqlen_prefill is not None or cuda_graph is None:
logits, speculative_logits = self.model.forward(
input_ids=input_ids, if cu_seqlen_prefill is None:
position_ids=position_ids, logits, speculative_logits = self.compiled_model(
cu_seqlen_prefill=cu_seqlen_prefill, input_ids=input_ids,
kv_cache=kv_cache, position_ids=position_ids,
block_tables=block_tables, cu_seqlen_prefill=cu_seqlen_prefill,
slots=slots, kv_cache=kv_cache,
input_lengths=input_lengths, block_tables=block_tables,
max_s=max_s, slots=slots,
prefill_cache_indices=batch.prefill_cache_indices, input_lengths=input_lengths,
lm_head_indices=lm_head_indices, max_s=max_s,
) prefill_cache_indices=batch.prefill_cache_indices,
lm_head_indices=lm_head_indices,
)
else:
logits, speculative_logits = self.model.forward(
input_ids=input_ids,
position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill,
kv_cache=kv_cache,
block_tables=block_tables,
slots=slots,
input_lengths=input_lengths,
max_s=max_s,
prefill_cache_indices=batch.prefill_cache_indices,
lm_head_indices=lm_head_indices,
)
if batch.prefill_cache_indices is not None: if batch.prefill_cache_indices is not None:
batch.prefill_cache_indices = None batch.prefill_cache_indices = None
return logits, speculative_logits return logits, speculative_logits

View File

@ -0,0 +1,51 @@
import torch
from typing import Optional, Tuple
from transformers import (
AutoProcessor,
)
from text_generation_server.models.custom_modeling.idefics2 import (
Idefics2ForConditionalGeneration,
)
from text_generation_server.models.vlm_causal_lm import VlmCausalLM
class Idefics2(VlmCausalLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
use_medusa: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
self.processor = AutoProcessor.from_pretrained(
model_id,
revision=revision,
trust_remote_code=trust_remote_code,
# XXX: Extremely important to cap resolution in order to limit
# VRAM usage.
size={"longest_edge": 448, "shortest_edge": 378},
)
super().__init__(
model_cls=Idefics2ForConditionalGeneration,
model_id=model_id,
revision=revision,
quantize=quantize,
use_medusa=use_medusa,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
def get_layer_config(self, model) -> Tuple[int, int, int]:
return (
len(model.text_model.model.layers),
model.text_model.model.num_key_value_heads,
model.text_model.model.head_size,
)
def max_past(self) -> Optional[int]:
return getattr(self.model.text_model, "max_past", None)

View File

@ -1,6 +1,6 @@
import torch import torch
from typing import Optional from typing import Optional, Tuple
from transformers import ( from transformers import (
AutoProcessor, AutoProcessor,
@ -34,3 +34,13 @@ class LlavaNext(VlmCausalLM):
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
def get_layer_config(self, model) -> Tuple[int, int, int]:
return (
len(model.language_model.model.layers),
model.language_model.model.num_key_value_heads,
model.language_model.model.head_size,
)
def max_past(self) -> Optional[int]:
return getattr(self.model.language_model, "max_past", None)

View File

@ -27,7 +27,14 @@ class Model(ABC):
): ):
self.model = model.eval() self.model = model.eval()
self.tokenizer = tokenizer self.tokenizer = tokenizer
# all_special_ids is not set correctly if the rust tokenizer is unpacked
# TODO report this to transformers.
other_special_ids = {
id for id, token in tokenizer.added_tokens_decoder.items() if token.special
}
self.all_special_ids = set(tokenizer.all_special_ids) self.all_special_ids = set(tokenizer.all_special_ids)
self.all_special_ids.update(other_special_ids)
self.requires_padding = requires_padding self.requires_padding = requires_padding
self.dtype = dtype self.dtype = dtype
self.device = device self.device = device

View File

@ -64,6 +64,46 @@ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
return height // patch_size, width // patch_size return height // patch_size, width // patch_size
def image_text_replacement(image_input, config, image_id) -> str:
if config.model_type == "idefics2":
# TODO technically depends on image splitting which is not implemented.
num_features = 320
return (
"<fake_token_around_image>"
+ "<image>" * num_features
+ "<fake_token_around_image>"
)
elif config.model_type == "llava_next":
height, width = image_input["image_sizes"][image_id]
num_features = get_number_of_features(height, width, config)
from loguru import logger
logger.info(f"Found {num_features} in image of resolution {height}x{width}")
return "<image>" * num_features
else:
raise RuntimeError(f"Unknown config {config.model_type} for multimodal")
def get_unpadded_features(
height: int, width: int, npatches: int, num_patch_height: int, num_patch_width: int
) -> Tuple[int, int]:
current_height = npatches * num_patch_height
current_width = npatches * num_patch_width
aspect_ratio: float = width / height
current_aspect_ratio: float = current_width / current_height
if aspect_ratio > current_aspect_ratio:
new_height = (height * current_width) // width
current_height = new_height
else:
new_width = (width * current_height) // height
current_width = new_width
unpadded_features = current_height * current_width
newline_features = current_height
return (unpadded_features, newline_features)
def get_number_of_features(height: int, width: int, config) -> int: def get_number_of_features(height: int, width: int, config) -> int:
# From config # From config
# Hardcoded for CLIP for now # Hardcoded for CLIP for now
@ -81,12 +121,9 @@ def get_number_of_features(height: int, width: int, config) -> int:
image_grid_pinpoints, image_grid_pinpoints,
image_size, image_size,
) )
unpadded_features, newline_features = get_unpadded_features(
height_of_patch = math.ceil(height / width * npatches) height, width, npatches, num_patch_height, num_patch_width
)
unpadded_features = npatches * height_of_patch * num_patch_height * num_patch_width
# They are only added after width
newline_features = height_of_patch * num_patch_width
# The base patch covers the entire image # The base patch covers the entire image
base_features = npatches**2 base_features = npatches**2
return unpadded_features + newline_features + base_features return unpadded_features + newline_features + base_features
@ -99,12 +136,9 @@ def load_data_uri(image_uri: str) -> Image.Image:
return image return image
# assert get_number_of_features(889, 1024) == 2634, f"{get_number_of_features(889, 1024)}"
# assert get_number_of_features(640, 640) == 2928
class VlmCausalLMBatch(FlashMistralBatch): class VlmCausalLMBatch(FlashMistralBatch):
pixel_values: Optional[List[torch.Tensor]] pixel_values: Optional[List[torch.Tensor]]
pixel_attention_mask: Optional[List[torch.Tensor]]
image_sizes: Optional[List[Tuple[int, int]]] image_sizes: Optional[List[Tuple[int, int]]]
@classmethod @classmethod
@ -112,6 +146,7 @@ class VlmCausalLMBatch(FlashMistralBatch):
def concatenate(cls, batches): def concatenate(cls, batches):
batch = super(VlmCausalLMBatch, cls).concatenate(batches) batch = super(VlmCausalLMBatch, cls).concatenate(batches)
batch.pixel_values = None batch.pixel_values = None
batch.pixel_attention_mask = None
batch.image_sizes = None batch.image_sizes = None
return batch return batch
@ -119,6 +154,7 @@ class VlmCausalLMBatch(FlashMistralBatch):
def filter(self, request_ids: List[int]): def filter(self, request_ids: List[int]):
batch = super().filter(request_ids) batch = super().filter(request_ids)
batch.pixel_values = None batch.pixel_values = None
batch.pixel_attention_mask = None
batch.image_sizes = None batch.image_sizes = None
return batch return batch
@ -130,6 +166,7 @@ class VlmCausalLMBatch(FlashMistralBatch):
for r in requests: for r in requests:
chunks = split(r.inputs) chunks = split(r.inputs)
full_text = "" full_text = ""
image_id = 0
for chunk in chunks: for chunk in chunks:
if chunk["type"] == "text": if chunk["type"] == "text":
full_text += chunk["content"] full_text += chunk["content"]
@ -147,9 +184,7 @@ class VlmCausalLMBatch(FlashMistralBatch):
"Cannot process input image not starting with data:" "Cannot process input image not starting with data:"
) )
image_input = processor.image_processor(image, return_tensors="pt") image_input = processor.image_processor(image, return_tensors="pt")
height, width = image_input["image_sizes"][0] full_text += image_text_replacement(image_input, config, image_id)
num_features = get_number_of_features(height, width, config)
full_text += "<image>" * num_features
image_inputs.append(image_input) image_inputs.append(image_input)
else: else:
raise RuntimeError(f"Invalid chunk type {chunk['type']}") raise RuntimeError(f"Invalid chunk type {chunk['type']}")
@ -161,12 +196,21 @@ class VlmCausalLMBatch(FlashMistralBatch):
batch_inputs, truncation=True, max_length=max_truncation batch_inputs, truncation=True, max_length=max_truncation
)["input_ids"] )["input_ids"]
if image_inputs: if image_inputs:
image_inputs = { image_input = image_inputs[0]
new_image_inputs = {
"pixel_values": torch.cat( "pixel_values": torch.cat(
[img["pixel_values"] for img in image_inputs], dim=0 [img["pixel_values"] for img in image_inputs], dim=0
), ),
"image_sizes": torch.cat([img["image_sizes"] for img in image_inputs]),
} }
if "pixel_attention_mask" in image_input:
new_image_inputs["pixel_attention_mask"] = torch.cat(
[img["pixel_attention_mask"] for img in image_inputs], dim=0
)
if "image_sizes" in image_input:
new_image_inputs["image_sizes"] = torch.cat(
[img["image_sizes"] for img in image_inputs], dim=0
)
image_inputs = new_image_inputs
else: else:
image_inputs = None image_inputs = None
return batch_tokenized_inputs, image_inputs return batch_tokenized_inputs, image_inputs
@ -187,9 +231,19 @@ class VlmCausalLMBatch(FlashMistralBatch):
batch = cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device) batch = cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)
if image_inputs is not None: if image_inputs is not None:
batch.pixel_values = image_inputs["pixel_values"].to(device=device) batch.pixel_values = image_inputs["pixel_values"].to(device=device)
batch.image_sizes = image_inputs["image_sizes"].to(device=device) if "pixel_attention_mask" in image_inputs:
batch.pixel_attention_mask = image_inputs["pixel_attention_mask"].to(
device=device
)
else:
batch.pixel_attention_mask = None
if "image_sizes" in image_inputs:
batch.image_sizes = image_inputs["image_sizes"].to(device=device)
else:
batch.image_sizes = None
else: else:
batch.pixel_values = None batch.pixel_values = None
batch.pixel_attention_mask = None
batch.image_sizes = None batch.image_sizes = None
return batch return batch
@ -199,16 +253,6 @@ class VlmCausalLM(BaseFlashMistral):
def batch_type(self) -> Type[VlmCausalLMBatch]: def batch_type(self) -> Type[VlmCausalLMBatch]:
return VlmCausalLMBatch return VlmCausalLMBatch
def get_layer_config(self, model) -> Tuple[int, int, int]:
return (
len(model.language_model.model.layers),
model.language_model.model.num_key_value_heads,
model.language_model.model.head_size,
)
def max_past(self) -> Optional[int]:
return getattr(self.model.language_model, "max_past", None)
def forward( def forward(
self, batch: VlmCausalLMBatch self, batch: VlmCausalLMBatch
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
@ -270,17 +314,14 @@ class VlmCausalLM(BaseFlashMistral):
max_s = min(self.max_past(), max_s) max_s = min(self.max_past(), max_s)
bs = input_ids.shape[0] bs = input_ids.shape[0]
padded_bs = bs
if bs == 3:
padded_bs = 4
elif 3 < bs <= 8:
padded_bs = 8
elif bs > 8:
padded_bs = (bs + 7) // 8 * 8
# Try to find an associated cuda graph # Try to find an associated cuda graph
cuda_graph = self.cuda_graphs.get(padded_bs, None) bs = input_ids.shape[0]
sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs])
if sorted_padded_bs:
# Get associated cuda graph
cuda_graph = self.cuda_graphs[sorted_padded_bs[0]]
else:
cuda_graph = None
if cu_seqlen_prefill is not None or cuda_graph is None: if cu_seqlen_prefill is not None or cuda_graph is None:
logits, speculative_logits = self.model.forward( logits, speculative_logits = self.model.forward(
input_ids=input_ids, input_ids=input_ids,
@ -294,12 +335,15 @@ class VlmCausalLM(BaseFlashMistral):
prefill_cache_indices=batch.prefill_cache_indices, prefill_cache_indices=batch.prefill_cache_indices,
lm_head_indices=lm_head_indices, lm_head_indices=lm_head_indices,
pixel_values=batch.pixel_values, pixel_values=batch.pixel_values,
pixel_attention_mask=batch.pixel_attention_mask,
image_sizes=batch.image_sizes, image_sizes=batch.image_sizes,
) )
if batch.prefill_cache_indices is not None: if batch.prefill_cache_indices is not None:
batch.prefill_cache_indices = None batch.prefill_cache_indices = None
if batch.pixel_values is not None: if batch.pixel_values is not None:
batch.pixel_values = None batch.pixel_values = None
if batch.pixel_attention_mask is not None:
batch.pixel_attention_mask = None
if batch.image_sizes is not None: if batch.image_sizes is not None:
batch.image_sizes = None batch.image_sizes = None
return logits, speculative_logits return logits, speculative_logits

View File

@ -756,6 +756,19 @@ class TensorParallelHead(SuperLayer):
class TensorParallelColumnLinear(SuperLayer): class TensorParallelColumnLinear(SuperLayer):
@classmethod
def load_gate_up(cls, config, prefix: str, weights, bias: bool):
"""Specific method when the QKV was joined after the fact"""
weight = weights.get_weights_col_packed_gate_up(
prefix, quantize=config.quantize
)
if bias:
raise NotImplementedError("packed_gate_up only implemented without bias")
else:
bias = None
linear = get_linear(weight, bias, config.quantize)
return cls(linear)
@classmethod @classmethod
def load_qkv(cls, config, prefix: str, weights, bias: bool): def load_qkv(cls, config, prefix: str, weights, bias: bool):
"""Specific method when the QKV was joined after the fact""" """Specific method when the QKV was joined after the fact"""

View File

@ -143,6 +143,8 @@ class FrequencyPenaltyLogitsProcessor(LogitsProcessor):
score = torch.gather(scores, 1, input_ids) score = torch.gather(scores, 1, input_ids)
# if score < 0 then penalty has to be multiplied to reduce the previous token probability # if score < 0 then penalty has to be multiplied to reduce the previous token probability
score = -torch.where(score < 0, score * self.penalty, score / self.penalty) score = -torch.where(score < 0, score * self.penalty, score / self.penalty)
# set score to 0 where input_ids is a padding token
score *= input_ids.ne(0)
return scores.scatter_add_(1, input_ids, score) return scores.scatter_add_(1, input_ids, score)
@ -168,6 +170,8 @@ class HeterogeneousFrequencyPenaltyLogitsProcessor(LogitsProcessor):
score = -torch.where( score = -torch.where(
score < 0, score * self.penalty_tensor, score / self.penalty_tensor score < 0, score * self.penalty_tensor, score / self.penalty_tensor
) )
# set score to 0 where input_ids is a padding token
score *= input_ids.ne(0)
return scores.scatter_add_(1, input_ids, score) return scores.scatter_add_(1, input_ids, score)

View File

@ -1,7 +1,6 @@
import torch import torch
# vllm imports from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM
from vllm._C import cache_ops, ops
_PARTITION_SIZE = 512 _PARTITION_SIZE = 512
@ -13,7 +12,18 @@ def reshape_and_cache(
value_cache: torch.Tensor, value_cache: torch.Tensor,
slots: torch.Tensor, slots: torch.Tensor,
): ):
cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots, "auto", 1.0) if IS_CUDA_SYSTEM:
from vllm._C import cache_ops
cache_ops.reshape_and_cache(
key, value, key_cache, value_cache, slots, "auto", 1.0
)
elif IS_ROCM_SYSTEM:
from vllm import cache_ops
cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots)
else:
raise ValueError("vllm is not supported on your system")
def attention( def attention(
@ -55,21 +65,43 @@ def attention(
# to parallelize. # to parallelize.
use_v1 = max_s <= 8192 and (max_num_partitions == 1 or num_seqs * num_heads > 512) use_v1 = max_s <= 8192 and (max_num_partitions == 1 or num_seqs * num_heads > 512)
if use_v1: if use_v1:
ops.paged_attention_v1( if IS_CUDA_SYSTEM:
out, from vllm._C import ops
query,
key_cache, ops.paged_attention_v1(
value_cache, out,
kv_head_mapping, query,
softmax_scale, key_cache,
block_tables, value_cache,
input_lengths, kv_head_mapping,
block_size, softmax_scale,
max_s, block_tables,
None, input_lengths,
"auto", block_size,
1.0, max_s,
) None,
"auto",
1.0,
)
elif IS_ROCM_SYSTEM:
from vllm import attention_ops
attention_ops.paged_attention_v1(
out,
query,
key_cache,
value_cache,
kv_head_mapping,
softmax_scale,
block_tables,
input_lengths,
block_size,
max_s,
None,
)
else:
raise ValueError("vllm is not supported on your system")
else: else:
# Run PagedAttention V2. # Run PagedAttention V2.
assert _PARTITION_SIZE % block_size == 0 assert _PARTITION_SIZE % block_size == 0
@ -84,21 +116,46 @@ def attention(
device=out.device, device=out.device,
) )
max_logits = torch.empty_like(exp_sums) max_logits = torch.empty_like(exp_sums)
ops.paged_attention_v2(
out, if IS_CUDA_SYSTEM:
exp_sums, from vllm._C import ops
max_logits,
tmp_output, ops.paged_attention_v2(
query, out,
key_cache, exp_sums,
value_cache, max_logits,
kv_head_mapping, tmp_output,
softmax_scale, query,
block_tables, key_cache,
input_lengths, value_cache,
block_size, kv_head_mapping,
max_s, softmax_scale,
None, block_tables,
"auto", input_lengths,
1.0, block_size,
) max_s,
None,
"auto",
1.0,
)
elif IS_ROCM_SYSTEM:
from vllm import attention_ops
attention_ops.paged_attention_v2(
out,
exp_sums,
max_logits,
tmp_output,
query,
key_cache,
value_cache,
kv_head_mapping,
softmax_scale,
block_tables,
input_lengths,
block_size,
max_s,
None,
)
else:
raise ValueError("vllm is not supported on your system")

View File

@ -1,5 +1,5 @@
import re import re
from typing import List, Optional, Tuple from typing import List, Optional, Tuple, Set, Union
import math import math
import torch import torch
@ -143,12 +143,22 @@ class StopSequenceCriteria:
class StoppingCriteria: class StoppingCriteria:
def __init__( def __init__(
self, self,
eos_token_id: int, eos_token_ids: Optional[Union[Set[int], int]],
stop_sequence_criterias: List[StopSequenceCriteria], stop_sequence_criterias: List[StopSequenceCriteria],
max_new_tokens: int = 20, max_new_tokens: int = 20,
ignore_eos_token: bool = False, ignore_eos_token: bool = False,
): ):
self.eos_token_id = eos_token_id if eos_token_ids is None:
eos_token_ids = set()
elif isinstance(eos_token_ids, int):
eos_token_ids = set([eos_token_ids])
elif isinstance(eos_token_ids, set):
eos_token_ids = eos_token_ids
else:
raise RuntimeError(
f"eos_token_ids is of invalid type {type(eos_token_ids)}, expected int, None or set[int]"
)
self.eos_token_ids = eos_token_ids
self.stop_sequence_criterias = stop_sequence_criterias self.stop_sequence_criterias = stop_sequence_criterias
self.max_new_tokens = max_new_tokens self.max_new_tokens = max_new_tokens
self.current_tokens = 0 self.current_tokens = 0
@ -160,7 +170,10 @@ class StoppingCriteria:
if self.current_tokens >= self.max_new_tokens: if self.current_tokens >= self.max_new_tokens:
return True, FinishReason.FINISH_REASON_LENGTH return True, FinishReason.FINISH_REASON_LENGTH
if not self.ignore_eos_token and last_token == self.eos_token_id: if isinstance(last_token, torch.Tensor):
last_token = last_token.item()
if not self.ignore_eos_token and last_token in self.eos_token_ids:
return True, FinishReason.FINISH_REASON_EOS_TOKEN return True, FinishReason.FINISH_REASON_EOS_TOKEN
if self.stop_sequence_criterias: if self.stop_sequence_criterias:
@ -184,8 +197,10 @@ class StoppingCriteria:
stop_sequence_criterias = [ stop_sequence_criterias = [
StopSequenceCriteria(sequence) for sequence in pb.stop_sequences StopSequenceCriteria(sequence) for sequence in pb.stop_sequences
] ]
# TODO Hack because eos_token_id cannot be what we want.
eos_token_id = getattr(tokenizer, "_eos_token_ids", tokenizer.eos_token_id)
return StoppingCriteria( return StoppingCriteria(
tokenizer.eos_token_id, eos_token_id,
stop_sequence_criterias, stop_sequence_criterias,
pb.max_new_tokens, pb.max_new_tokens,
pb.ignore_eos_token, pb.ignore_eos_token,
@ -273,7 +288,7 @@ class HeterogeneousNextTokenChooser:
else None else None
) )
if any([x != 1.0 for x in temperature]): if any(x != 1.0 for x in temperature):
do_sample = [ do_sample = [
sample or x != 1.0 for x, sample in zip(temperature, do_sample) sample or x != 1.0 for x, sample in zip(temperature, do_sample)
] ]
@ -281,15 +296,15 @@ class HeterogeneousNextTokenChooser:
HeterogeneousTemperatureLogitsWarper(temperature, dtype, device) HeterogeneousTemperatureLogitsWarper(temperature, dtype, device)
) )
if any([x != 0 for x in top_k]): if any(x != 0 for x in top_k):
do_sample = [sample or x != 0 for x, sample in zip(top_k, do_sample)] do_sample = [sample or x != 0 for x, sample in zip(top_k, do_sample)]
warpers.append(HeterogeneousTopKLogitsWarper(top_k, device)) warpers.append(HeterogeneousTopKLogitsWarper(top_k, device))
if any([x < 1.0 for x in top_p]): if any(x < 1.0 for x in top_p):
do_sample = [sample or x < 1.0 for x, sample in zip(top_p, do_sample)] do_sample = [sample or x < 1.0 for x, sample in zip(top_p, do_sample)]
warpers.append(HeterogeneousTopPLogitsWarper(top_p, dtype, device)) warpers.append(HeterogeneousTopPLogitsWarper(top_p, dtype, device))
if any([x < 1.0 for x in typical_p]): if any(x < 1.0 for x in typical_p):
do_sample = [sample or x < 1.0 for x, sample in zip(typical_p, do_sample)] do_sample = [sample or x < 1.0 for x, sample in zip(typical_p, do_sample)]
warpers.append(HeterogeneousTypicalLogitsWarper(typical_p, dtype, device)) warpers.append(HeterogeneousTypicalLogitsWarper(typical_p, dtype, device))

View File

@ -141,6 +141,12 @@ class Weights:
return weight return weight
def get_weights_col_packed_qkv(self, prefix: str, quantize: str): def get_weights_col_packed_qkv(self, prefix: str, quantize: str):
return self.get_weights_col_packed(prefix, quantize, 3)
def get_weights_col_packed_gate_up(self, prefix: str, quantize: str):
return self.get_weights_col_packed(prefix, quantize, 2)
def get_weights_col_packed(self, prefix: str, quantize: str, blocks: int):
""" """
Highly specific when the underlying tensor is a simple cat of Q,K,V instead of being Highly specific when the underlying tensor is a simple cat of Q,K,V instead of being
already alternating Q,K,V within the main tensor already alternating Q,K,V within the main tensor
@ -181,8 +187,8 @@ class Weights:
else: else:
slice_ = self._get_slice(f"{prefix}.weight") slice_ = self._get_slice(f"{prefix}.weight")
total_size = slice_.get_shape()[0] total_size = slice_.get_shape()[0]
assert total_size % 3 == 0, "Prepacked qkv is not divisible by 3" assert total_size % blocks == 0, f"Prepacked is not divisible by {blocks}"
single_size = total_size // 3 single_size = total_size // blocks
world_size = self.process_group.size() world_size = self.process_group.size()
rank = self.process_group.rank() rank = self.process_group.rank()
@ -192,10 +198,11 @@ class Weights:
block_size = single_size // world_size block_size = single_size // world_size
start = rank * block_size start = rank * block_size
stop = (rank + 1) * block_size stop = (rank + 1) * block_size
q = slice_[start:stop] tensors = []
k = slice_[start + single_size : stop + single_size] for i in range(blocks):
v = slice_[start + 2 * single_size : stop + 2 * single_size] tensor = slice_[start + i * single_size : stop + i * single_size]
weight = torch.cat([q, k, v], dim=0) tensors.append(tensor)
weight = torch.cat(tensors, dim=0)
weight = weight.to(device=self.device) weight = weight.to(device=self.device)
weight = weight.to(dtype=self.dtype) weight = weight.to(dtype=self.dtype)
return weight return weight