Mirror of https://github.com/huggingface/text-generation-inference.git
Synced 2025-04-25 12:02:08 +00:00

Merge branch 'main' into mi300-compat
Commit 7502367043

Cargo.lock (generated): 248 lines changed
The lockfile changes are dependency updates:

  cc 1.0.92 -> 1.0.94
  darling, darling_core, darling_macro 0.14.4 -> 0.20.8 (their proc-macro dependency moves from syn 1.0.109 to syn 2.0.60)
  derive_builder, derive_builder_core, derive_builder_macro 0.12.0 -> 0.20.0 (also moved from syn 1.0.109 to syn 2.0.60)
  either 1.10.0 -> 1.11.0
  jobserver 0.1.29 -> 0.1.30
  monostate, monostate-impl 0.1.11 -> 0.1.12
  num 0.4.1 -> 0.4.2
  prettyplease 0.2.17 -> 0.2.19
  proc-macro2 1.0.79 -> 1.0.81
  serde, serde_derive 1.0.197 -> 1.0.198
  serde_json 1.0.115 -> 1.0.116
  syn 2.0.58 -> 2.0.60, and every "syn 2.0.58" dependency entry is updated to "syn 2.0.60"
  tokenizers: the duplicate 0.14.1 and 0.15.2 entries collapse into a single 0.19.1 entry, and regex-syntax 0.7.5 (needed only by tokenizers 0.14.1) is removed
  windows-targets and the windows_* target crates 0.52.4 -> 0.52.5, with windows_i686_gnullvm 0.52.5 added as a new target crate
  text-generation-benchmark, text-generation-client, text-generation-launcher, text-generation-router 2.0.0 -> 2.0.1
  (package checksums are updated accordingly for every bumped crate)
Cargo.toml (workspace):

@@ -9,11 +9,15 @@ members = [
 resolver = "2"
 
 [workspace.package]
-version = "2.0.0"
+version = "2.0.1"
 edition = "2021"
 authors = ["Olivier Dehaene"]
 homepage = "https://github.com/huggingface/text-generation-inference"
 
+[workspace.dependencies]
+tokenizers = { version = "0.19.1", features = ["http"] }
+hf-hub = { version = "0.3.1", features = ["tokio"] }
+
 [profile.release]
 debug = 1
 incremental = true

benchmark/Cargo.toml:

@@ -23,9 +23,9 @@ serde_json = "1.0"
 tabled = "0.14.0"
 text-generation-client = { path = "../router/client" }
 thiserror = "1.0.48"
-tokenizers = { version = "0.14.0", features = ["http"] }
+tokenizers = { workspace = true }
 tokio = { version = "1.32.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync", "macros"] }
 tui = {package = "ratatui", version = "0.23", default-features = false, features = ["crossterm"]}
 tracing = "0.1.37"
 tracing-subscriber = { version = "0.3.17", features = ["json", "env-filter"] }
-hf-hub = "0.3.1"
+hf-hub = { workspace = true }
docs/openapi.json:

@@ -10,7 +10,7 @@
 "name": "Apache 2.0",
 "url": "https://www.apache.org/licenses/LICENSE-2.0"
 },
-"version": "2.0.0"
+"version": "2.0.1"
 },
 "paths": {
 "/": {

@@ -408,9 +408,14 @@
 },
 "responses": {
 "200": {
-"description": "Generated Text",
+"description": "Generated Chat Completion",
 "content": {
 "application/json": {
+"schema": {
+"$ref": "#/components/schemas/ChatCompletion"
+}
+},
+"text/event-stream": {
 "schema": {
 "$ref": "#/components/schemas/ChatCompletionChunk"
 }

@@ -492,11 +497,16 @@
 },
 "responses": {
 "200": {
-"description": "Generated Text",
+"description": "Generated Chat Completion",
 "content": {
 "application/json": {
 "schema": {
-"$ref": "#/components/schemas/ChatCompletionChunk"
+"$ref": "#/components/schemas/Completion"
+}
+},
+"text/event-stream": {
+"schema": {
+"$ref": "#/components/schemas/CompletionCompleteChunk"
 }
 }
 }

@@ -930,7 +940,7 @@
 "tool_prompt": {
 "type": "string",
 "description": "A prompt to be appended before the tools",
-"example": "\"Based on the conversation, please choose the most appropriate tool to use: \"",
+"example": "\"You will be presented with a JSON schema representing a set of tools.\nIf the user request lacks of sufficient information to make a precise tool selection: Do not invent any tool's properties, instead notify with an error message.\n\nJSON Schema:\n\"",
 "nullable": true
 },
 "tools": {

@@ -1071,7 +1081,10 @@
 "example": "mistralai/Mistral-7B-Instruct-v0.2"
 },
 "prompt": {
-"type": "string",
+"type": "array",
+"items": {
+"type": "string"
+},
 "description": "The prompt to generate completions for.",
 "example": "What is Deep Learning?"
 },

@@ -1234,17 +1247,17 @@
 "type": "object",
 "required": [
 "name",
-"parameters"
+"arguments"
 ],
 "properties": {
+"arguments": {},
 "description": {
 "type": "string",
 "nullable": true
 },
 "name": {
 "type": "string"
-},
-"parameters": {}
+}
 }
 },
 "GenerateParameters": {

@@ -1260,7 +1273,7 @@
 },
 "decoder_input_details": {
 "type": "boolean",
-"default": "true"
+"default": "false"
 },
 "details": {
 "type": "boolean",

@@ -1285,6 +1298,7 @@
 "$ref": "#/components/schemas/GrammarType"
 }
 ],
+"default": "null",
 "nullable": true
 },
 "max_new_tokens": {

@@ -1478,6 +1492,7 @@
 "max_batch_total_tokens",
 "max_waiting_tokens",
 "validation_workers",
+"max_client_batch_size",
 "version"
 ],
 "properties": {

@@ -1503,6 +1518,11 @@
 "example": "2",
 "minimum": 0
 },
+"max_client_batch_size": {
+"type": "integer",
+"example": "32",
+"minimum": 0
+},
 "max_concurrent_requests": {
 "type": "integer",
 "description": "Router Parameters",
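To make the schema change concrete, here is a minimal sketch of a request against the updated `/v1/completions` endpoint, where `prompt` is now an array of strings. The endpoint path and the `prompt` shape come from the OpenAPI fragment above; the host, model name, `max_tokens` value, and response handling are assumptions for illustration only.

```python
# Hedged sketch of a /v1/completions request under the updated schema.
# Host, model name, and parameter values are placeholders.
import requests

payload = {
    "model": "tgi",  # placeholder; TGI serves whatever model it was launched with
    "prompt": ["What is Deep Learning?", "What is speculative decoding?"],  # now an array of strings
    "max_tokens": 32,
}

resp = requests.post("http://localhost:3000/v1/completions", json=payload, timeout=60)
resp.raise_for_status()
for choice in resp.json()["choices"]:
    print(choice["text"])
```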
docs/source/conceptual/guidance.md:

@@ -1,8 +1,10 @@
 # Guidance
 
-Text Generation Inference (TGI) now supports [JSON and regex grammars](#grammar-and-constraints) and [tools and functions](#tools-and-functions) to help developer guide LLM responses to fit their needs.
+Text Generation Inference (TGI) now supports [JSON and regex grammars](#grammar-and-constraints) and [tools and functions](#tools-and-functions) to help developers guide LLM responses to fit their needs.
 
-These feature are available starting from version `1.4.3`. They are accessible via the [text_generation](https://pypi.org/project/text-generation/) library and is compatible with OpenAI's client libraries. The following guide will walk you through the new features and how to use them!
+These feature are available starting from version `1.4.3`. They are accessible via the [text_generation](https://pypi.org/project/text-generation/) library. The tool support is compatible with OpenAI's client libraries. The following guide will walk you through the new features and how to use them!
 
+> The Grammar guidance support is currently only available in the TGI API due to lack of support in Open AI API.
+
 ## Quick Start
 

@@ -16,7 +18,7 @@ If you're not up to date, grab the latest version and let's get started!
 
 - [The Grammar Parameter](#the-grammar-parameter): Shape your AI's responses with precision.
 - [Constrain with Pydantic](#constrain-with-pydantic): Define a grammar using Pydantic models.
-- [JSON Schema Integration](#json-schema-integration): Fine grain control over your requests via JSON schema.
+- [JSON Schema Integration](#json-schema-integration): Fine-grained control over your requests via JSON schema.
 - [Using the client](#using-the-client): Use TGI's client libraries to shape the AI's responses.
 
 ### Tools and Functions

@@ -72,9 +74,9 @@ curl localhost:3000/generate \
 
 ```
 
-A grammar can be defined using Pydantic models, JSON schema, or regular expressions. The AI will then generate a response that conforms to the specified grammar.
+A grammar can be defined using Pydantic models, JSON schemas, or regular expressions. The AI will then generate a response that conforms to the specified grammar.
 
-> Note: A grammar must compile to a intermediate representation to constrain the output. Grammar compilation is a computationally expensive and may take a few seconds to complete on the first request. Subsequent requests will use the cached grammar and will be much faster.
+> Note: A grammar must compile to an intermediate representation to constrain the output. Grammar compilation is a computationally expensive and may take a few seconds to complete on the first request. Subsequent requests will use the cached grammar and will be much faster.
 
 ### Constrain with Pydantic
 

@@ -151,7 +153,7 @@ json_schema = {
 }
 
 data = {
-    "inputs": "[INST]convert to JSON: I saw a puppy a cat and a raccoon during my bike ride in the park [/INST]",
+    "inputs": "convert to JSON: I saw a puppy a cat and a raccoon during my bike ride in the park",
     "parameters": {
         "max_new_tokens": 200,
         "repetition_penalty": 1.3,
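The guidance changes keep the same request shape as the `data = {...}` example above. As a rough illustration, a grammar-constrained call to `/generate` might look like the sketch below; the Pydantic model, server address, and response handling are assumptions, while the `grammar` payload follows the pattern documented in that guide.

```python
# Hedged sketch of a grammar-constrained /generate request, following the
# pattern shown in the guidance document above. The Animal model is invented.
import requests
from pydantic import BaseModel

class Animal(BaseModel):
    name: str
    legs: int

payload = {
    "inputs": "convert to JSON: I saw a puppy a cat and a raccoon during my bike ride in the park",
    "parameters": {
        "max_new_tokens": 200,
        "repetition_penalty": 1.3,
        "grammar": {"type": "json", "value": Animal.model_json_schema()},
    },
}

resp = requests.post("http://localhost:3000/generate", json=payload, timeout=60)
print(resp.json()["generated_text"])
```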
docs/source/conceptual/speculation.md:

@@ -1,5 +1,6 @@
 ## Speculation
 
+
 Speculative decoding, assisted generation, Medusa, and others are a few different names for the same idea.
 The idea is to generate tokens *before* the large model actually runs, and only *check* if those tokens where valid.
 

@@ -36,7 +37,7 @@ In order to use medusa models in TGI, simply point to a medusa enabled model, an
 
 
 If you don't have a medusa model, or don't have the resource to fine-tune, you can try to use `n-gram`.
-Ngram works by trying to find in the previous sequence existing tokens that match, and use those as speculation.
+N-gram works by trying to find matching tokens in the previous sequence, and use those as speculation for generating new tokens. For example, if the tokens "np.mean" appear multiple times in the sequence, the model can speculate that the next continuation of the tokens "np." is probably also "mean".
 
 This is an extremely simple method, which works best for code, or highly repetitive text. This might not be beneficial, if the speculation misses too much.
 
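The n-gram description above can be pictured with a toy lookup: find the most recent earlier occurrence of the current suffix and propose the tokens that followed it. This is a conceptual sketch only, not TGI's implementation.

```python
# Toy illustration of n-gram speculation: propose the tokens that followed the
# most recent earlier occurrence of the current suffix. Not TGI's actual code.
def ngram_speculate(tokens, suffix_len=2, n_speculative=3):
    if len(tokens) <= suffix_len:
        return []
    suffix = tokens[-suffix_len:]
    # Search backwards for an earlier occurrence of the suffix.
    for start in range(len(tokens) - suffix_len - 1, -1, -1):
        if tokens[start:start + suffix_len] == suffix:
            # Speculate the tokens that followed that occurrence.
            return tokens[start + suffix_len:start + suffix_len + n_speculative]
    return []

# "np . mean" was seen before, so after "np ." we speculate "mean" again.
print(ngram_speculate(["np", ".", "mean", "(", "x", ")", ";", "np", "."]))
# -> ['mean', '(', 'x']
```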
docs/source/conceptual/streaming.md:

@@ -15,7 +15,7 @@ Token streaming is the mode in which the server returns the tokens one by one as
 />
 </div>
 
-With token streaming, the server can start returning the tokens one by one before having to generate the whole response. Users can have a sense of the generation's quality earlier than the end of the generation. This has different positive effects:
+With token streaming, the server can start returning the tokens one by one before having to generate the whole response. Users can have a sense of the generation's quality before the end of the generation. This has different positive effects:
 
 * Users can get results orders of magnitude earlier for extremely long queries.
 * Seeing something in progress allows users to stop the generation if it's not going in the direction they expect.

@@ -116,7 +116,7 @@ curl -N 127.0.0.1:8080/generate_stream \
 First, we need to install the `@huggingface/inference` library.
 `npm install @huggingface/inference`
 
-If you're using the free Inference API, you can use `HfInference`. If you're using inference endpoints, you can use `HfInferenceEndpoint`. Let's
+If you're using the free Inference API, you can use `HfInference`. If you're using inference endpoints, you can use `HfInferenceEndpoint`.
 
 We can create a `HfInferenceEndpoint` providing our endpoint URL and credential.
 
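The streaming doc shows the JavaScript client; for a Python counterpart, a hedged sketch using `huggingface_hub`'s `InferenceClient` against the local server from the curl example is below. The prompt and `max_new_tokens` value are arbitrary.

```python
# Hedged sketch: consume the token stream from a local TGI server with
# huggingface_hub's InferenceClient. Server URL matches the curl example above.
from huggingface_hub import InferenceClient

client = InferenceClient("http://127.0.0.1:8080")
for token in client.text_generation("What is Deep Learning?", max_new_tokens=40, stream=True):
    # With stream=True (and no details), each iteration yields the next text chunk.
    print(token, end="", flush=True)
print()
```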
docs/source/index.md:

@@ -18,8 +18,8 @@ Text Generation Inference implements many optimizations and features, such as:
 - Logits warper (temperature scaling, top-p, top-k, repetition penalty)
 - Stop sequences
 - Log probabilities
-- Custom Prompt Generation: Easily generate text by providing custom prompts to guide the model's output.
 - Fine-tuning Support: Utilize fine-tuned models for specific tasks to achieve higher accuracy and performance.
+- [Guidance](../conceptual/guidance): Enable function calling and tool-use by forcing the model to generate structured outputs based on your own predefined output schemas.
 
 Text Generation Inference is used in production by multiple projects, such as:
 
integration-tests/conftest.py:

@@ -293,6 +293,7 @@ def launcher(event_loop):
         dtype: Optional[str] = None,
         revision: Optional[str] = None,
         max_input_length: Optional[int] = None,
+        max_batch_prefill_tokens: Optional[int] = None,
         max_total_tokens: Optional[int] = None,
     ):
         port = random.randint(8000, 10_000)

@@ -334,6 +335,9 @@ def launcher(event_loop):
         if max_input_length:
             args.append("--max-input-length")
             args.append(str(max_input_length))
+        if max_batch_prefill_tokens:
+            args.append("--max-batch-prefill-tokens")
+            args.append(str(max_batch_prefill_tokens))
         if max_total_tokens:
             args.append("--max-total-tokens")
             args.append(str(max_total_tokens))

@@ -371,6 +375,7 @@ def launcher(event_loop):
         dtype: Optional[str] = None,
         revision: Optional[str] = None,
         max_input_length: Optional[int] = None,
+        max_batch_prefill_tokens: Optional[int] = None,
         max_total_tokens: Optional[int] = None,
     ):
         port = random.randint(8000, 10_000)

@@ -395,6 +400,9 @@ def launcher(event_loop):
         if max_input_length:
             args.append("--max-input-length")
             args.append(str(max_input_length))
+        if max_batch_prefill_tokens:
+            args.append("--max-batch-prefill-tokens")
+            args.append(str(max_batch_prefill_tokens))
         if max_total_tokens:
             args.append("--max-total-tokens")
             args.append(str(max_total_tokens))
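With the fixture change above, an integration test can pass the new argument through the `launcher` fixture. A hypothetical sketch is shown below; the model id is borrowed from the snapshots that follow, while the shard count, token budgets, and context-manager usage are assumptions about how the existing fixture is used.

```python
# Hypothetical use of the extended launcher fixture: max_batch_prefill_tokens
# is forwarded to the server as --max-batch-prefill-tokens.
import pytest

@pytest.fixture(scope="module")
def tiny_model_handle(launcher):
    with launcher(
        "TinyLlama/TinyLlama-1.1B-Chat-v1.0",  # model id taken from the snapshots below
        num_shard=1,                           # assumed existing fixture argument
        max_input_length=1024,
        max_batch_prefill_tokens=1024,         # new argument added in this diff
        max_total_tokens=2048,
    ) as handle:
        yield handle
```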
Integration-test snapshots (generated): the only change in every recorded completion response is the version string in "system_fingerprint", which moves from "2.0.0-native" to "2.0.1-native". This applies to the two non-streaming completion snapshots (model "TinyLlama/TinyLlama-1.1B-Chat-v1.0", with usages of 100/60 and 36/8 completion/prompt tokens) and to each chunk of the streaming completion snapshot.
|
"id": "",
|
||||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||||
"object": "text_completion",
|
"object": "text_completion",
|
||||||
"system_fingerprint": "2.0.0-native"
|
"system_fingerprint": "2.0.1-native"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"choices": [
|
"choices": [
|
||||||
@ -432,7 +432,7 @@
|
|||||||
"id": "",
|
"id": "",
|
||||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||||
"object": "text_completion",
|
"object": "text_completion",
|
||||||
"system_fingerprint": "2.0.0-native"
|
"system_fingerprint": "2.0.1-native"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"choices": [
|
"choices": [
|
||||||
@ -447,7 +447,7 @@
|
|||||||
"id": "",
|
"id": "",
|
||||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||||
"object": "text_completion",
|
"object": "text_completion",
|
||||||
"system_fingerprint": "2.0.0-native"
|
"system_fingerprint": "2.0.1-native"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"choices": [
|
"choices": [
|
||||||
@ -462,7 +462,7 @@
|
|||||||
"id": "",
|
"id": "",
|
||||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||||
"object": "text_completion",
|
"object": "text_completion",
|
||||||
"system_fingerprint": "2.0.0-native"
|
"system_fingerprint": "2.0.1-native"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"choices": [
|
"choices": [
|
||||||
@ -477,7 +477,7 @@
|
|||||||
"id": "",
|
"id": "",
|
||||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||||
"object": "text_completion",
|
"object": "text_completion",
|
||||||
"system_fingerprint": "2.0.0-native"
|
"system_fingerprint": "2.0.1-native"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"choices": [
|
"choices": [
|
||||||
@ -492,7 +492,7 @@
|
|||||||
"id": "",
|
"id": "",
|
||||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||||
"object": "text_completion",
|
"object": "text_completion",
|
||||||
"system_fingerprint": "2.0.0-native"
|
"system_fingerprint": "2.0.1-native"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"choices": [
|
"choices": [
|
||||||
@ -507,7 +507,7 @@
|
|||||||
"id": "",
|
"id": "",
|
||||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||||
"object": "text_completion",
|
"object": "text_completion",
|
||||||
"system_fingerprint": "2.0.0-native"
|
"system_fingerprint": "2.0.1-native"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"choices": [
|
"choices": [
|
||||||
@ -522,7 +522,7 @@
|
|||||||
"id": "",
|
"id": "",
|
||||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||||
"object": "text_completion",
|
"object": "text_completion",
|
||||||
"system_fingerprint": "2.0.0-native"
|
"system_fingerprint": "2.0.1-native"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"choices": [
|
"choices": [
|
||||||
@ -537,7 +537,7 @@
|
|||||||
"id": "",
|
"id": "",
|
||||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||||
"object": "text_completion",
|
"object": "text_completion",
|
||||||
"system_fingerprint": "2.0.0-native"
|
"system_fingerprint": "2.0.1-native"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"choices": [
|
"choices": [
|
||||||
@ -552,7 +552,7 @@
|
|||||||
"id": "",
|
"id": "",
|
||||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||||
"object": "text_completion",
|
"object": "text_completion",
|
||||||
"system_fingerprint": "2.0.0-native"
|
"system_fingerprint": "2.0.1-native"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"choices": [
|
"choices": [
|
||||||
@ -567,7 +567,7 @@
|
|||||||
"id": "",
|
"id": "",
|
||||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||||
"object": "text_completion",
|
"object": "text_completion",
|
||||||
"system_fingerprint": "2.0.0-native"
|
"system_fingerprint": "2.0.1-native"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"choices": [
|
"choices": [
|
||||||
@ -582,7 +582,7 @@
|
|||||||
"id": "",
|
"id": "",
|
||||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||||
"object": "text_completion",
|
"object": "text_completion",
|
||||||
"system_fingerprint": "2.0.0-native"
|
"system_fingerprint": "2.0.1-native"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"choices": [
|
"choices": [
|
||||||
@ -597,6 +597,6 @@
|
|||||||
"id": "",
|
"id": "",
|
||||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||||
"object": "text_completion",
|
"object": "text_completion",
|
||||||
"system_fingerprint": "2.0.0-native"
|
"system_fingerprint": "2.0.1-native"
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
@ -11,7 +11,7 @@
|
|||||||
"id": "",
|
"id": "",
|
||||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||||
"object": "text_completion",
|
"object": "text_completion",
|
||||||
"system_fingerprint": "2.0.0-native",
|
"system_fingerprint": "2.0.1-native",
|
||||||
"usage": {
|
"usage": {
|
||||||
"completion_tokens": 5,
|
"completion_tokens": 5,
|
||||||
"prompt_tokens": 6,
|
"prompt_tokens": 6,
|
||||||
@@ -0,0 +1,89 @@
+{
+"details": {
+"best_of_sequences": null,
+"finish_reason": "length",
+"generated_tokens": 10,
+"prefill": [
+{
+"id": 1,
+"logprob": null,
+"text": "<s>"
+},
+{
+"id": 3735,
+"logprob": -8.5625,
+"text": "Test"
+},
+{
+"id": 2159,
+"logprob": -10.78125,
+"text": "request"
+}
+],
+"seed": 0,
+"tokens": [
+{
+"id": 288,
+"logprob": -0.2854004,
+"special": false,
+"text": "ing"
+},
+{
+"id": 264,
+"logprob": -0.37573242,
+"special": false,
+"text": " a"
+},
+{
+"id": 633,
+"logprob": -0.09301758,
+"special": false,
+"text": " new"
+},
+{
+"id": 4480,
+"logprob": -0.3322754,
+"special": false,
+"text": " feature"
+},
+{
+"id": 297,
+"logprob": -0.8510742,
+"special": false,
+"text": " in"
+},
+{
+"id": 272,
+"logprob": -0.13464355,
+"special": false,
+"text": " the"
+},
+{
+"id": 2039,
+"logprob": 0.0,
+"special": false,
+"text": " game"
+},
+{
+"id": 28723,
+"logprob": -0.89990234,
+"special": false,
+"text": "."
+},
+{
+"id": 13,
+"logprob": 0.0,
+"special": false,
+"text": "\n"
+},
+{
+"id": 13,
+"logprob": 0.0,
+"special": false,
+"text": "\n"
+}
+],
+"top_tokens": null
+},
+"generated_text": "Test requesting a new feature in the game.\n\n"
+}
File diff suppressed because it is too large
@@ -0,0 +1,73 @@
+{
+"details": {
+"best_of_sequences": null,
+"finish_reason": "length",
+"generated_tokens": 10,
+"prefill": [],
+"seed": null,
+"tokens": [
+{
+"id": 330,
+"logprob": -0.13000488,
+"special": false,
+"text": " A"
+},
+{
+"id": 13088,
+"logprob": -0.6713867,
+"special": false,
+"text": " chicken"
+},
+{
+"id": 349,
+"logprob": -0.2980957,
+"special": false,
+"text": " is"
+},
+{
+"id": 6398,
+"logprob": -0.060638428,
+"special": false,
+"text": " sitting"
+},
+{
+"id": 356,
+"logprob": -0.27319336,
+"special": false,
+"text": " on"
+},
+{
+"id": 264,
+"logprob": -0.140625,
+"special": false,
+"text": " a"
+},
+{
+"id": 17972,
+"logprob": -0.040405273,
+"special": false,
+"text": " pile"
+},
+{
+"id": 302,
+"logprob": -0.0002708435,
+"special": false,
+"text": " of"
+},
+{
+"id": 2445,
+"logprob": -0.095336914,
+"special": false,
+"text": " money"
+},
+{
+"id": 28723,
+"logprob": -0.0068359375,
+"special": false,
+"text": "."
+}
+],
+"top_tokens": null
+},
+"generated_text": " A chicken is sitting on a pile of money."
+}
File diff suppressed because it is too large
@@ -30,7 +30,7 @@
 "id": "",
 "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
 "object": "text_completion",
-"system_fingerprint": "2.0.0-native",
+"system_fingerprint": "2.0.1-native",
 "usage": {
 "completion_tokens": 37,
 "prompt_tokens": 524,

@@ -30,7 +30,7 @@
 "id": "",
 "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
 "object": "text_completion",
-"system_fingerprint": "2.0.0-native",
+"system_fingerprint": "2.0.1-native",
 "usage": {
 "completion_tokens": 37,
 "prompt_tokens": 524,

@@ -30,7 +30,7 @@
 "id": "",
 "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
 "object": "text_completion",
-"system_fingerprint": "2.0.0-native",
+"system_fingerprint": "2.0.1-native",
 "usage": {
 "completion_tokens": 48,
 "prompt_tokens": 320,

@@ -23,5 +23,5 @@
 "id": "",
 "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
 "object": "text_completion",
-"system_fingerprint": "2.0.0-native"
+"system_fingerprint": "2.0.1-native"
 }
81
integration-tests/models/test_idefics2.py
Normal file
@@ -0,0 +1,81 @@
+import pytest
+import base64
+
+
+# TODO fix the server parsser to count inline image tokens correctly
+def get_chicken():
+    with open("integration-tests/images/chicken_on_money.png", "rb") as image_file:
+        encoded_string = base64.b64encode(image_file.read())
+    return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
+
+
+@pytest.fixture(scope="module")
+def flash_idefics2_next_handle(launcher):
+    with launcher(
+        "HuggingFaceM4/idefics2-8b",
+    ) as handle:
+        yield handle
+
+
+@pytest.fixture(scope="module")
+async def flash_idefics2_next(flash_idefics2_next_handle):
+    await flash_idefics2_next_handle.health(300)
+    return flash_idefics2_next_handle.client
+
+
+@pytest.mark.asyncio
+@pytest.mark.private
+async def test_flash_idefics2_next_simple(flash_idefics2_next, response_snapshot):
+    chicken = get_chicken()
+    response = await flash_idefics2_next.generate(
+        f"User:![]({chicken})Write me a short story<end_of_utterance> \nAssistant:",
+        max_new_tokens=10,
+    )
+    assert (
+        response.generated_text == " A chicken is sitting on a pile of money."
+    ), f"{repr(response.generated_text)}"
+    assert response.details.generated_tokens == 10
+    assert response == response_snapshot
+
+
+@pytest.mark.asyncio
+@pytest.mark.private
+async def test_flash_idefics2_next_all_params(flash_idefics2_next, response_snapshot):
+    response = await flash_idefics2_next.generate(
+        "Test request",
+        max_new_tokens=10,
+        repetition_penalty=1.2,
+        return_full_text=True,
+        stop_sequences=["test"],
+        temperature=0.5,
+        top_p=0.9,
+        top_k=10,
+        truncate=5,
+        typical_p=0.9,
+        watermark=True,
+        decoder_input_details=True,
+        seed=0,
+    )
+
+    assert response.details.generated_tokens == 10
+    assert response == response_snapshot
+
+
+@pytest.mark.asyncio
+@pytest.mark.private
+async def test_flash_idefics2_next_load(
+    flash_idefics2_next, generate_load, response_snapshot
+):
+    chicken = get_chicken()
+    responses = await generate_load(
+        flash_idefics2_next,
+        f"User:![]({chicken})Write me a short story<end_of_utterance> \nAssistant:",
+        max_new_tokens=10,
+        n=4,
+    )
+    generated_texts = [r.generated_text for r in responses]
+    assert generated_texts[0] == " A chicken is sitting on a pile of money."
+    assert len(generated_texts) == 4
+    assert all([r.generated_text == generated_texts[0] for r in responses])
+
+    assert responses == response_snapshot
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "text-generation-integration-tests"
-version = "2.0.0"
+version = "2.0.1"
 description = "Text Generation Inference integration tests"
 authors = ["Nicolas Patry <nicolas@huggingface.co>"]
@@ -1,71 +1,94 @@
-import { check, randomSeed } from 'k6';
+import { check } from 'k6';
+import { scenario } from 'k6/execution';
 import http from 'k6/http';
 import { Trend, Counter } from 'k6/metrics';
-import { randomItem } from 'https://jslib.k6.io/k6-utils/1.2.0/index.js';
 
-const seed = 0;
-const host = __ENV.HOST || '127.0.0.1:8000';
+const host = __ENV.HOST;
+const model_id = __ENV.MODEL_ID;
 const timePerToken = new Trend('time_per_token', true);
 const tokens = new Counter('tokens');
 const new_tokens = new Counter('new_tokens');
 const input_tokens = new Counter('input_tokens');
+const max_new_tokens = 50;
 
-randomSeed(seed);
 // const shareGPT = JSON.parse(open("ShareGPT_V3_unfiltered_cleaned_split.json"))
 const shareGPT = JSON.parse(open("small.json"))
 
 
-export function get_options(reference_latency_ms){
+export function get_options() {
     return {
         thresholds: {
             http_req_failed: ['rate==0'],
-            time_per_token: [{
-                threshold: `p(50)<${5 * reference_latency_ms}`,
-                abortOnFail: true,
-                delayAbortEval: '10s'
-            }],
+            // time_per_token: [{
+            //     threshold: `p(50)<${5 * reference_latency_ms}`,
+            //     abortOnFail: true,
+            //     delayAbortEval: '10s'
+            // }],
         },
         scenarios: {
-            load_test: {
+            single_user: {
                 executor: 'constant-arrival-rate',
                 duration: '60s',
-                preAllocatedVUs: 10,
-                rate: 10,
+                preAllocatedVUs: 1,
+                rate: 1,
                 timeUnit: '1s',
             },
+            // load_test: {
+            //     executor: 'constant-arrival-rate',
+            //     duration: '60s',
+            //     preAllocatedVUs: 100,
+            //     rate: 1,
+            //     timeUnit: '1s',
+            // },
+            // breakpoint: {
+            //     executor: 'ramping-arrival-rate', //Assure load increase if the system slows
+            //     preAllocatedVUs: 1000,
+            //     stages: [
+            //         { duration: '60s', target: 100 }, // just slowly ramp-up to a HUGE load
+            //     ],
+            // },
+            // throughput: {
+            //     executor: 'shared-iterations',
+            //     vus: 100,
+            //     iterations: 200,
+            //     maxDuration: '40s',
+            // },
         },
     };
 }
 
+function generate_payload(gpt, max_new_tokens) {
+    const input = gpt["conversations"][0]["value"];
+    return { "messages": [{ "role": "user", "content": input }], "temperature": 0, "model": `${model_id}`, "max_tokens": max_new_tokens }
+}
 
-export function run(host, generate_payload, max_new_tokens) {
-    const headers = {'Content-Type': 'application/json'};
-    const query = randomItem(shareGPT);
-    const payload = JSON.stringify(generate_payload(query));
-    const res = http.post(`http://${host}/generate`, payload, {
+export const options = get_options();
+
+export default function run() {
+    const headers = { 'Content-Type': 'application/json' };
+    const query = shareGPT[scenario.iterationInTest % shareGPT.length];
+    const payload = JSON.stringify(generate_payload(query, max_new_tokens));
+    const res = http.post(`http://${host}/v1/chat/completions`, payload, {
        headers,
    });
-    if(res.status >= 400 && res.status < 500){
+    if (res.status >= 400 && res.status < 500) {
        return;
    }
 
    check(res, {
-        'Post status is 200': (r) => res.status === 200,
+        'Post status is 200': (res) => res.status === 200,
    });
    const duration = res.timings.duration;
 
    if (res.status === 200) {
        const body = res.json();
-        const n_tokens = body.details.tokens.length;
-        const latency_ms_per_token = duration / n_tokens;
+        const completion_tokens = body.usage.completion_tokens;
+        const latency_ms_per_token = duration / completion_tokens;
        timePerToken.add(latency_ms_per_token);
-        const latency_in_s = latency_ms_per_token / 1000;
-        const individual_throughput = 1 / latency_in_s;
-        const _input_tokens = body.details.prefill.length;
-        tokens.add(n_tokens + _input_tokens);
-        input_tokens.add(_input_tokens);
-        new_tokens.add(n_tokens);
+        const prompt_tokens = body.usage.prompt_tokens;
+        input_tokens.add(prompt_tokens);
+        new_tokens.add(completion_tokens);
+        tokens.add(completion_tokens + prompt_tokens);
    }
 }
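For reference, a rough Python equivalent of what the reworked k6 script now sends and measures is sketched below. It is not part of the commit; the host, model id, prompt, and the use of the requests library are illustrative assumptions, but the endpoint, payload shape, and usage-based accounting mirror the script above.

    # Sketch only: mirrors the k6 script's payload and metrics, not part of the benchmark suite.
    import time
    import requests

    HOST = "127.0.0.1:8080"  # assumption: wherever the server is listening
    MODEL_ID = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"  # assumption: any served model id

    payload = {
        "messages": [{"role": "user", "content": "Write me a short story"}],
        "temperature": 0,
        "model": MODEL_ID,
        "max_tokens": 50,
    }

    start = time.monotonic()
    res = requests.post(f"http://{HOST}/v1/chat/completions", json=payload, timeout=60)
    duration_ms = (time.monotonic() - start) * 1000

    if res.status_code == 200:
        usage = res.json()["usage"]
        completion_tokens = usage["completion_tokens"]
        prompt_tokens = usage["prompt_tokens"]
        # Same accounting as the k6 script: per-token latency from completion tokens,
        # total token count as prompt + completion.
        print("time_per_token_ms:", duration_ms / completion_tokens)
        print("tokens:", prompt_tokens + completion_tokens)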
@@ -1,17 +0,0 @@
-import { get_options, run } from "./common.js";
-
-const reference_latency_ms = 70;
-const host = __ENV.HOST || '127.0.0.1:8000';
-const max_new_tokens = 50;
-
-
-function generate_payload(gpt){
-    const input = gpt["conversations"][0]["value"];
-    return {"inputs": input, "parameters": {"max_new_tokens": max_new_tokens, "decoder_input_details": true}}
-}
-
-export const options = get_options(reference_latency_ms);
-
-export default function(){
-    run(host, generate_payload, max_new_tokens);
-}

@@ -1,17 +0,0 @@
-import { get_options, run } from "./common.js";
-
-const reference_latency_ms = 22;
-const host = __ENV.HOST || '127.0.0.1:8000';
-const max_new_tokens = 50;
-
-
-function generate_payload(gpt){
-    const input = gpt["conversations"][0]["value"];
-    return {"prompt": input, "temperature": 0.5, "ignore_eos": true}
-}
-
-export const options = get_options(reference_latency_ms);
-
-export default function(){
-    run(host, generate_payload, max_new_tokens);
-}
@@ -21,7 +21,7 @@ axum-tracing-opentelemetry = "0.14.1"
 text-generation-client = { path = "client" }
 clap = { version = "4.4.5", features = ["derive", "env"] }
 futures = "0.3.28"
-hf-hub = { version = "0.3.0", features = ["tokio"] }
+hf-hub = { workspace = true }
 jsonschema = { version = "0.17.1", features = ["draft202012"] }
 metrics = "0.21.1"
 metrics-exporter-prometheus = { version = "0.12.1", features = [] }
@@ -33,7 +33,7 @@ reqwest = { version = "0.11.20", features = [] }
 serde = "1.0.188"
 serde_json = "1.0.107"
 thiserror = "1.0.48"
-tokenizers = { version = "0.15.1", features = ["http"] }
+tokenizers = { workspace = true}
 tokio = { version = "1.32.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
 tokio-stream = "0.1.14"
 tower-http = { version = "0.4.4", features = ["cors"] }
@@ -114,8 +114,12 @@ impl Client {
         let truncate = min(max_input_length, max_prefill_tokens - n_tokens);
 
         let mut inputs = String::new();
-        inputs.push_str(";
         inputs.push_str(&"_test ".to_string().repeat(max_input_length as usize));
+        if n_tokens == 0 {
+            // 1 request is enough to test vision heads.
+            // Sending images on other queries messes up easily with truncation.
+            inputs.push_str("");
+        }
 
         requests.push(Request {
             id: 0,
@@ -57,6 +57,31 @@ fn select_best_resolution(
     best_fit.unwrap_or((original_height, original_width))
 }
 
+fn get_unpadded_features(
+    height: usize,
+    width: usize,
+    npatches: usize,
+    num_patch_height: usize,
+    num_patch_width: usize,
+) -> (usize, usize) {
+    let current_height = npatches * num_patch_height;
+    let current_width = npatches * num_patch_width;
+
+    let aspect_ratio: f64 = width as f64 / height as f64;
+    let current_aspect_ratio: f64 = current_width as f64 / current_height as f64;
+    let (current_height, current_width) = if aspect_ratio > current_aspect_ratio {
+        let new_height = (height * current_width) / width;
+        (new_height, current_width)
+    } else {
+        let new_width = (width * current_height) / height;
+        (current_height, new_width)
+    };
+
+    let unpadded_features = current_height * current_width;
+    let newline_features = current_height;
+    (unpadded_features, newline_features)
+}
+
 impl LlavaNext {
     pub fn get_number_of_features(&self, height: usize, width: usize) -> usize {
         let image_size = self.vision_config.image_size;
@@ -65,11 +90,9 @@ impl LlavaNext {
         let npatches = image_size / patch_size;
         let (num_patch_height, num_patch_width) =
             get_anyres_image_grid_shape(height, width, &self.image_grid_pinpoints, image_size);
-        // Ceil
-        let height_of_patch = (height * npatches + width - 1) / width;
-        let unpadded_features = npatches * height_of_patch * num_patch_height * num_patch_width;
-        // They are only added after width
-        let newline_features = height_of_patch * num_patch_width;
+        let (unpadded_features, newline_features) =
+            get_unpadded_features(height, width, npatches, num_patch_height, num_patch_width);
         // The base patch covers the entire image
         let base_features = npatches.pow(2);
         unpadded_features + newline_features + base_features
@@ -84,6 +107,17 @@ pub struct ClipVisionModel {
     patch_size: usize,
 }
 
+#[derive(Clone, Debug, Serialize, Deserialize)]
+#[serde(tag = "model_type")]
+#[serde(rename_all = "snake_case")]
+pub struct Idefics2 {}
+
+impl Idefics2 {
+    pub fn get_number_of_features(&self, _height: usize, _width: usize) -> usize {
+        320
+    }
+}
+
 #[derive(Clone, Debug, Serialize, Deserialize)]
 #[serde(tag = "model_type")]
 #[serde(rename_all = "snake_case")]
@@ -92,6 +126,7 @@ pub enum Config {
     ClipVisionModel(ClipVisionModel),
     Mistral,
     Idefics,
+    Idefics2(Idefics2),
     Ssm,
     GptBigcode,
     Santacoder,
@@ -146,13 +181,17 @@ mod test {
             ],
         };
 
+        let slots = config.get_number_of_features(20, 20);
+        assert_eq!(slots, 1176);
         let slots = config.get_number_of_features(640, 640);
         assert_eq!(slots, 2928);
         let slots = config.get_number_of_features(480, 640);
         assert_eq!(slots, 2340);
         let slots = config.get_number_of_features(899, 1024);
-        assert_eq!(slots, 2732);
+        assert_eq!(slots, 2634);
         let slots = config.get_number_of_features(1024, 899);
-        assert_eq!(slots, 3320);
+        assert_eq!(slots, 2640);
+        let slots = config.get_number_of_features(1067, 1600);
+        assert_eq!(slots, 2144);
     }
 }
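As a quick illustration of the feature count introduced above (not part of the diff), the Python sketch below mirrors the same arithmetic as get_unpadded_features plus the base and newline terms. The npatches and grid values used in the example call are made-up inputs for illustration, not values taken from any real checkpoint config.

    # Sketch of the LLaVA-Next feature-count formula; inputs are illustrative only.
    def get_unpadded_features(height, width, npatches, num_patch_height, num_patch_width):
        current_height = npatches * num_patch_height
        current_width = npatches * num_patch_width

        aspect_ratio = width / height
        current_aspect_ratio = current_width / current_height
        if aspect_ratio > current_aspect_ratio:
            # image is wider than the patch grid: shrink the effective height
            current_height = (height * current_width) // width
        else:
            # image is taller: shrink the effective width
            current_width = (width * current_height) // height

        unpadded_features = current_height * current_width
        newline_features = current_height  # one newline slot per unpadded row
        return unpadded_features, newline_features


    def get_number_of_features(height, width, npatches, num_patch_height, num_patch_width):
        unpadded, newline = get_unpadded_features(
            height, width, npatches, num_patch_height, num_patch_width
        )
        base = npatches**2  # the base patch covers the entire image
        return unpadded + newline + base


    # Example call with made-up grid values:
    print(get_number_of_features(480, 640, npatches=24, num_patch_height=1, num_patch_width=2))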
@@ -73,9 +73,9 @@ pub struct HubTokenizerConfig {
 }
 
 impl HubTokenizerConfig {
-    pub fn from_file(filename: &std::path::Path) -> Self {
-        let content = std::fs::read_to_string(filename).unwrap();
-        serde_json::from_str(&content).unwrap_or_default()
+    pub fn from_file<P: AsRef<std::path::Path>>(filename: P) -> Option<Self> {
+        let content = std::fs::read_to_string(filename).ok()?;
+        serde_json::from_str(&content).ok()
     }
 }
 
@@ -116,6 +116,7 @@ mod token_serde {
                 ))
             }
         }
+        Value::Null => Ok(None),
         _ => Err(de::Error::custom("invalid token format")),
     }
 }
@@ -168,9 +169,12 @@ pub struct Info {
 
 #[derive(Clone, Debug, Deserialize, ToSchema, Default)]
 pub(crate) struct GenerateParameters {
+    /// Generate best_of sequences and return the one if the highest token logprobs.
     #[serde(default)]
     #[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 1)]
     pub best_of: Option<usize>,
+
+    /// The value used to module the logits distribution.
     #[serde(default)]
     #[schema(
         exclusive_minimum = 0.0,
@@ -179,6 +183,9 @@ pub(crate) struct GenerateParameters {
         example = 0.5
     )]
     pub temperature: Option<f32>,
+
+    /// The parameter for repetition penalty. 1.0 means no penalty.
+    /// See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
     #[serde(default)]
     #[schema(
         exclusive_minimum = 0.0,
@@ -187,6 +194,10 @@ pub(crate) struct GenerateParameters {
         example = 1.03
     )]
     pub repetition_penalty: Option<f32>,
+
+    /// The parameter for frequency penalty. 1.0 means no penalty
+    /// Penalize new tokens based on their existing frequency in the text so far,
+    /// decreasing the model's likelihood to repeat the same line verbatim.
     #[serde(default)]
     #[schema(
         exclusive_minimum = -2.0,
@@ -195,9 +206,13 @@ pub(crate) struct GenerateParameters {
         example = 0.1
     )]
     pub frequency_penalty: Option<f32>,
+
+    /// The number of highest probability vocabulary tokens to keep for top-k-filtering.
     #[serde(default)]
     #[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 10)]
     pub top_k: Option<i32>,
+
+    /// Top-p value for nucleus sampling.
     #[serde(default)]
     #[schema(
         exclusive_minimum = 0.0,
@@ -207,6 +222,9 @@ pub(crate) struct GenerateParameters {
         example = 0.95
     )]
     pub top_p: Option<f32>,
+
+    /// Typical Decoding mass
+    /// See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information.
     #[serde(default)]
     #[schema(
         exclusive_minimum = 0.0,
@@ -216,30 +234,48 @@ pub(crate) struct GenerateParameters {
         example = 0.95
     )]
     pub typical_p: Option<f32>,
+
+    /// Activate logits sampling.
     #[serde(default)]
     #[schema(default = "false", example = true)]
     pub do_sample: bool,
+
+    /// Maximum number of tokens to generate.
     #[serde(default = "default_max_new_tokens")]
     #[schema(nullable = true, default = "100", example = "20")]
     pub max_new_tokens: Option<u32>,
+
+    /// Whether to prepend the prompt to the generated text
     #[serde(default)]
     #[schema(nullable = true, default = "null", example = false)]
     pub return_full_text: Option<bool>,
+
+    /// Stop generating tokens if a member of `stop` is generated.
     #[serde(default)]
     #[schema(inline, max_items = 4, example = json ! (["photographer"]))]
     pub stop: Vec<String>,
+
+    /// Truncate inputs tokens to the given size.
     #[serde(default)]
     #[schema(nullable = true, default = "null", example = "null")]
     pub truncate: Option<usize>,
+
+    /// Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226).
     #[serde(default)]
     #[schema(default = "false", example = true)]
     pub watermark: bool,
+
+    /// Whether to return generation details.
     #[serde(default)]
     #[schema(default = "true")]
     pub details: bool,
+
+    /// Whether to return decoder input token logprobs and ids.
     #[serde(default)]
-    #[schema(default = "true")]
+    #[schema(default = "false")]
     pub decoder_input_details: bool,
+
+    /// Random sampling seed.
     #[serde(default)]
     #[schema(
         exclusive_minimum = 0,
@@ -248,10 +284,15 @@ pub(crate) struct GenerateParameters {
         example = "null"
     )]
     pub seed: Option<u64>,
+
+    /// The number of highest probability vocabulary tokens to keep for top-n-filtering.
     #[serde(default)]
     #[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 5)]
     pub top_n_tokens: Option<u32>,
+
+    /// Grammar constraints for the generation.
     #[serde(default)]
+    #[schema(nullable = true, default = "null", example = "null")]
     pub grammar: Option<GrammarType>,
 }
 
@@ -548,7 +589,9 @@ pub(crate) struct ChatCompletionChoice {
 #[derive(Clone, Debug, Deserialize, Serialize, ToSchema)]
 pub(crate) struct ChatCompletionDelta {
     #[schema(example = "user")]
-    pub role: String,
+    // TODO Modify this to a true enum.
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub role: Option<String>,
     #[serde(default, skip_serializing_if = "Option::is_none")]
     #[schema(example = "What is Deep Learning?")]
     pub content: Option<String>,
@@ -582,6 +625,31 @@ impl ChatCompletionChunk {
         logprobs: Option<ChatCompletionLogprobs>,
         finish_reason: Option<String>,
     ) -> Self {
+        let delta = match (delta, tool_calls) {
+            (Some(delta), _) => ChatCompletionDelta {
+                role: Some("assistant".to_string()),
+                content: Some(delta),
+                tool_calls: None,
+            },
+            (None, Some(tool_calls)) => ChatCompletionDelta {
+                role: Some("assistant".to_string()),
+                content: None,
+                tool_calls: Some(DeltaToolCall {
+                    index: 0,
+                    id: String::new(),
+                    r#type: "function".to_string(),
+                    function: Function {
+                        name: None,
+                        arguments: tool_calls[0].to_string(),
+                    },
+                }),
+            },
+            (None, None) => ChatCompletionDelta {
+                role: None,
+                content: None,
+                tool_calls: None,
+            },
+        };
         Self {
             id: String::new(),
             object: "text_completion".to_string(),
@@ -590,19 +658,7 @@ impl ChatCompletionChunk {
             system_fingerprint,
             choices: vec![ChatCompletionChoice {
                 index: 0,
-                delta: ChatCompletionDelta {
-                    role: "assistant".to_string(),
-                    content: delta,
-                    tool_calls: tool_calls.map(|tc| DeltaToolCall {
-                        index: 0,
-                        id: String::new(),
-                        r#type: "function".to_string(),
-                        function: Function {
-                            name: None,
-                            arguments: tc[0].to_string(),
-                        },
-                    }),
-                },
+                delta,
                 logprobs,
                 finish_reason,
             }],
@@ -1,7 +1,7 @@
 use axum::http::HeaderValue;
 use clap::Parser;
 use hf_hub::api::tokio::{Api, ApiBuilder, ApiRepo};
-use hf_hub::{Repo, RepoType};
+use hf_hub::{Cache, Repo, RepoType};
 use opentelemetry::sdk::propagation::TraceContextPropagator;
 use opentelemetry::sdk::trace;
 use opentelemetry::sdk::trace::Sampler;
@@ -11,7 +11,7 @@ use opentelemetry_otlp::WithExportConfig;
 use std::fs::File;
 use std::io::BufReader;
 use std::net::{IpAddr, Ipv4Addr, SocketAddr};
-use std::path::Path;
+use std::path::{Path, PathBuf};
 use text_generation_client::{ClientError, ShardedClient};
 use text_generation_router::config::Config;
 use text_generation_router::{server, HubModelInfo, HubTokenizerConfig};
@@ -162,7 +162,6 @@ async fn main() -> Result<(), RouterError> {
     // Tokenizer instance
     // This will only be used to validate payloads
     let local_path = Path::new(&tokenizer_name);
-    let local_model = local_path.exists() && local_path.is_dir();
 
     // Shared API builder initialization
     let api_builder = || {
@@ -181,109 +180,113 @@ async fn main() -> Result<(), RouterError> {
     let use_api = revision.is_some() || !local_path.exists() || !local_path.is_dir();
 
     // Initialize API if needed
+    #[derive(Clone)]
+    enum Type {
+        Api(Api),
+        Cache(Cache),
+        None,
+    }
     let api = if use_api {
-        tracing::info!("Using the Hugging Face API");
-        match api_builder().build() {
-            Ok(api) => Some(api),
-            Err(_) => {
-                tracing::warn!("Unable to build the Hugging Face API");
-                None
+        if std::env::var("HF_HUB_OFFLINE") == Ok("1".to_string()) {
+            let cache = Cache::default();
+            tracing::warn!("Offline mode active using cache defaults");
+            Type::Cache(cache)
+        } else {
+            tracing::info!("Using the Hugging Face API");
+            match api_builder().build() {
+                Ok(api) => Type::Api(api),
+                Err(_) => {
+                    tracing::warn!("Unable to build the Hugging Face API");
+                    Type::None
+                }
            }
        }
    } else {
-        None
+        Type::None
    };
 
     // Load tokenizer and model info
-    let (tokenizer, model_info, config) = if local_model {
-        let tokenizer = Tokenizer::from_file(local_path.join("tokenizer.json")).ok();
-        let model_info = HubModelInfo {
-            model_id: tokenizer_name.to_string(),
-            sha: None,
-            pipeline_tag: None,
-        };
-        let config: Option<Config> = std::fs::read_to_string(local_path.join("config.json"))
-            .ok()
-            .as_ref()
-            .and_then(|c| serde_json::from_str(c).ok());
-
-        (tokenizer, model_info, config)
-    } else if let Some(api) = api.clone() {
-        let api_repo = api.repo(Repo::with_revision(
-            tokenizer_name.to_string(),
-            RepoType::Model,
-            revision.clone().unwrap_or_else(|| "main".to_string()),
-        ));
-
-        let tokenizer = match api_repo.get("tokenizer.json").await {
-            Ok(tokenizer_filename) => Tokenizer::from_file(tokenizer_filename).ok(),
-            Err(_) => get_base_tokenizer(&api, &api_repo).await,
-        };
-
-        let config: Option<Config> = api_repo.get("config.json").await.ok().and_then(|filename| {
-            std::fs::read_to_string(filename)
-                .ok()
-                .as_ref()
-                .and_then(|c| {
-                    let config: Result<Config, _> = serde_json::from_str(c);
-                    if let Err(err) = &config {
-                        tracing::warn!("Could not parse config {err:?}");
-                    }
-                    config.ok()
-                })
-        });
-
-        let model_info = get_model_info(&api_repo).await.unwrap_or_else(|| {
-            tracing::warn!("Could not retrieve model info from the Hugging Face hub.");
-            HubModelInfo {
-                model_id: tokenizer_name.to_string(),
-                sha: None,
-                pipeline_tag: None,
-            }
-        });
-
-        (tokenizer, model_info, config)
-    } else {
-        // No API and no local model
-        return Err(RouterError::ArgumentValidation(
-            "No local model found and no revision specified".to_string(),
-        ));
-    };
-
-    tracing::info!("Using config {config:?}");
-
-    // Load tokenizer config if found locally, or check if we can get it from the API if needed
-    let tokenizer_config = if let Some(path) = tokenizer_config_path {
-        tracing::info!("Using local tokenizer config from user specified path");
-        HubTokenizerConfig::from_file(&std::path::PathBuf::from(path))
-    } else if local_model {
-        tracing::info!("Using local tokenizer config");
-        HubTokenizerConfig::from_file(&local_path.join("tokenizer_config.json"))
-    } else {
-        match api {
-            Some(api) => {
-                tracing::info!("Using the Hugging Face API to retrieve tokenizer config");
-                let repo = Repo::with_revision(
-                    tokenizer_name.to_string(),
-                    RepoType::Model,
-                    revision.unwrap_or("main".to_string()),
-                );
-                get_tokenizer_config(&api.repo(repo))
-                    .await
-                    .unwrap_or_else(|| {
-                        tracing::warn!(
-                            "Could not retrieve tokenizer config from the Hugging Face hub."
-                        );
-                        HubTokenizerConfig::default()
-                    })
-            }
-            None => {
-                tracing::warn!("Could not find tokenizer config locally and no API specified");
-                HubTokenizerConfig::default()
-            }
+    let (tokenizer_filename, config_filename, tokenizer_config_filename, model_info) = match api {
+        Type::None => (
+            Some(local_path.join("tokenizer.json")),
+            Some(local_path.join("config.json")),
+            Some(local_path.join("tokenizer_config.json")),
+            None,
+        ),
+        Type::Api(api) => {
+            let api_repo = api.repo(Repo::with_revision(
+                tokenizer_name.to_string(),
+                RepoType::Model,
+                revision.clone().unwrap_or_else(|| "main".to_string()),
+            ));
+
+            let tokenizer_filename = match api_repo.get("tokenizer.json").await {
+                Ok(tokenizer_filename) => Some(tokenizer_filename),
+                Err(_) => get_base_tokenizer(&api, &api_repo).await,
+            };
+            let config_filename = api_repo.get("config.json").await.ok();
+            let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok();
+
+            let model_info = if let Some(model_info) = get_model_info(&api_repo).await {
+                Some(model_info)
+            } else {
+                tracing::warn!("Could not retrieve model info from the Hugging Face hub.");
+                None
+            };
+            (
+                tokenizer_filename,
+                config_filename,
+                tokenizer_config_filename,
+                model_info,
+            )
+        }
+        Type::Cache(cache) => {
+            let repo = cache.repo(Repo::with_revision(
+                tokenizer_name.to_string(),
+                RepoType::Model,
+                revision.clone().unwrap_or_else(|| "main".to_string()),
+            ));
+            (
+                repo.get("tokenizer.json"),
+                repo.get("config.json"),
+                repo.get("tokenizer_config.json"),
+                None,
+            )
        }
    };
+    let tokenizer: Option<Tokenizer> =
+        tokenizer_filename.and_then(|filename| Tokenizer::from_file(filename).ok());
+    let config: Option<Config> = config_filename.and_then(|filename| {
+        std::fs::read_to_string(filename)
+            .ok()
+            .as_ref()
+            .and_then(|c| {
+                let config: Result<Config, _> = serde_json::from_str(c);
+                if let Err(err) = &config {
+                    tracing::warn!("Could not parse config {err:?}");
+                }
+                config.ok()
+            })
+    });
+    let model_info = model_info.unwrap_or_else(|| HubModelInfo {
+        model_id: tokenizer_name.to_string(),
+        sha: None,
+        pipeline_tag: None,
+    });
+
+    // Read the JSON contents of the file as an instance of 'HubTokenizerConfig'.
+    let tokenizer_config: Option<HubTokenizerConfig> = if let Some(filename) = tokenizer_config_path
+    {
+        HubTokenizerConfig::from_file(filename)
+    } else {
+        tokenizer_config_filename.and_then(HubTokenizerConfig::from_file)
+    };
+    let tokenizer_config = tokenizer_config.unwrap_or_else(|| {
+        tracing::warn!("Could not find tokenizer config locally and no API specified");
+        HubTokenizerConfig::default()
+    });
+
+    tracing::info!("Using config {config:?}");
     if tokenizer.is_none() {
         tracing::warn!("Could not find a fast tokenizer implementation for {tokenizer_name}");
         tracing::warn!("Rust input length validation and truncation is disabled");
@@ -480,7 +483,7 @@ pub async fn get_model_info(api: &ApiRepo) -> Option<HubModelInfo> {
 }
 
 /// get base tokenizer
-pub async fn get_base_tokenizer(api: &Api, api_repo: &ApiRepo) -> Option<Tokenizer> {
+pub async fn get_base_tokenizer(api: &Api, api_repo: &ApiRepo) -> Option<PathBuf> {
     let config_filename = api_repo.get("config.json").await.ok()?;
 
     // Open the file in read-only mode with buffer.
@@ -497,8 +500,7 @@ pub async fn get_base_tokenizer(api: &Api, api_repo: &ApiRepo) -> Option<Tokeniz
             "main".to_string(),
        ));
 
-        let tokenizer_filename = api_base_repo.get("tokenizer.json").await.ok()?;
-        Tokenizer::from_file(tokenizer_filename).ok()
+        api_base_repo.get("tokenizer.json").await.ok()
    } else {
        None
    }
@@ -1000,6 +1000,7 @@ async fn chat_completions(
         tools,
         tool_choice,
         tool_prompt,
+        temperature,
         ..
     } = req;
 
@@ -1008,6 +1009,11 @@ async fn chat_completions(
     let logprobs = logprobs.unwrap_or(false);
     let tool_prompt = tool_prompt.unwrap_or_default();
     let stop = stop.unwrap_or_default();
+    // enable greedy only when temperature is 0
+    let (do_sample, temperature) = match temperature {
+        Some(temperature) if temperature == 0.0 => (false, None),
+        other => (true, other),
+    };
 
     // extract tool grammar if present
     let tool_grammar = match ToolGrammar::apply(tools, tool_choice) {
@@ -1054,13 +1060,13 @@ async fn chat_completions(
         inputs: inputs.to_string(),
         parameters: GenerateParameters {
             best_of: None,
-            temperature: req.temperature,
+            temperature,
             repetition_penalty,
             frequency_penalty: req.frequency_penalty,
             top_k: None,
             top_p: req.top_p,
             typical_p: None,
-            do_sample: true,
+            do_sample,
             max_new_tokens,
             return_full_text: None,
             stop,
@@ -1097,7 +1103,13 @@ async fn chat_completions(
                 let (content, tool_calls) = if tool_grammar.is_some() {
                     (None, Some(vec![stream_token.token.text]))
                 } else {
-                    (Some(stream_token.token.text), None)
+                    let content = if !stream_token.token.special {
+                        Some(stream_token.token.text)
+                    } else {
+                        None
+                    };
+
+                    (content, None)
                };
 
                event
@@ -540,7 +540,57 @@ fn prepare_input(
             inputs = modified_inputs;
             tokenizer_query
         }
-        Some(Config::Idefics) => RE.replace_all(&inputs, "<image>").into(),
+        Some(Config::Idefics2(config)) => {
+            let mut modified_inputs = String::with_capacity(inputs.len());
+            let mut tokenizer_query = String::with_capacity(inputs.len());
+            let mut start = 0;
+            for chunk in RE.find_iter(&inputs) {
+                let chunk_start = chunk.start();
+                let chunk_end = chunk.end();
+                if chunk_start != start {
+                    modified_inputs.push_str(&inputs[start..chunk_start]);
+                    tokenizer_query.push_str(&inputs[start..chunk_start]);
+                }
+                let (image_uri, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?;
+                let slots = config.get_number_of_features(height, width);
+                tokenizer_query.push_str("<fake_token_around_image>");
+                tokenizer_query.push_str(&"<image>".repeat(slots));
+                tokenizer_query.push_str("<fake_token_around_image>");
+
+                modified_inputs.push_str(&image_uri);
+                start = chunk_end;
+            }
+            if start != inputs.len() - 1 {
+                modified_inputs.push_str(&inputs[start..]);
+                tokenizer_query.push_str(&inputs[start..]);
+            }
+            inputs = modified_inputs;
+            tokenizer_query
+        }
+        Some(Config::Idefics) => {
+            let mut modified_inputs = String::with_capacity(inputs.len());
+            let mut tokenizer_query = String::with_capacity(inputs.len());
+            let mut start = 0;
+            for chunk in RE.find_iter(&inputs) {
+                let chunk_start = chunk.start();
+                let chunk_end = chunk.end();
+                if chunk_start != start {
+                    modified_inputs.push_str(&inputs[start..chunk_start]);
+                    tokenizer_query.push_str(&inputs[start..chunk_start]);
+                }
+                let (image_uri, _height, _width) = fetch_image(&inputs[chunk_start..chunk_end])?;
+                let slots = 1;
+                tokenizer_query.push_str(&"<image>".repeat(slots));
+                modified_inputs.push_str(&image_uri);
+                start = chunk_end;
+            }
+            if start != inputs.len() - 1 {
+                modified_inputs.push_str(&inputs[start..]);
+                tokenizer_query.push_str(&inputs[start..]);
+            }
+            inputs = modified_inputs;
+            tokenizer_query
+        }
         _ => inputs.clone(),
     };
 
|
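For Idefics2 the router keeps the image URI in `inputs` (so the image can be fetched later) while `tokenizer_query` receives `<fake_token_around_image>` delimiters around one `<image>` placeholder per visual slot the image will occupy after patching. A Python sketch of that expansion; the regex and the `fetch_image`/`number_of_features` callables are stand-ins for the router's own helpers:

    import re

    IMG_RE = re.compile(r"!\[\]\((.*?)\)")   # assumption: markdown-style image markers

    def expand_images(inputs, fetch_image, number_of_features):
        """Sketch of the Idefics2 branch: URIs stay in `inputs`, placeholders go to the tokenizer."""
        modified, query, start = [], [], 0
        for m in IMG_RE.finditer(inputs):
            modified.append(inputs[start:m.start()])
            query.append(inputs[start:m.start()])
            uri, height, width = fetch_image(m.group(0))     # returns data URI and image size
            slots = number_of_features(height, width)
            query.append("<fake_token_around_image>" + "<image>" * slots + "<fake_token_around_image>")
            modified.append(uri)
            start = m.end()
        modified.append(inputs[start:])
        query.append(inputs[start:])
        return "".join(modified), "".join(query)

The plain Idefics branch does the same walk but always uses a single `<image>` slot per image.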
server/poetry.lock (generated, 1081 changed lines): diff suppressed because it is too large.
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "text-generation-server"
-version = "2.0.0"
+version = "2.0.1"
 description = "Text Generation Inference Python gRPC Server"
 authors = ["Olivier Dehaene <olivier@huggingface.co>"]

@@ -24,9 +24,9 @@ opentelemetry-exporter-otlp = "^1.15.0"
 opentelemetry-instrumentation-grpc = "^0.36b0"
 hf-transfer = "^0.1.2"
 sentencepiece = "^0.1.97"
-tokenizers = "^0.15.0"
+tokenizers = "^0.19.1"
 huggingface-hub = "^0.19.3"
-transformers = "^4.39"
+transformers = "^4.40"
 einops = "^0.6.1"
 texttable = { version = "^1.6.7", optional = true }
 datasets = { version = "^2.14.0", optional = true }
@@ -5,7 +5,7 @@ click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
 colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
 deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
 einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
-filelock==3.13.3 ; python_version >= "3.9" and python_version < "3.13"
+filelock==3.13.4 ; python_version >= "3.9" and python_version < "3.13"
 fsspec==2024.2.0 ; python_version >= "3.9" and python_version < "3.13"
 googleapis-common-protos==1.63.0 ; python_version >= "3.9" and python_version < "3.13"
 grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
@@ -14,7 +14,7 @@ grpcio-status==1.62.1 ; python_version >= "3.9" and python_version < "3.13"
 grpcio==1.62.1 ; python_version >= "3.9" and python_version < "3.13"
 hf-transfer==0.1.6 ; python_version >= "3.9" and python_version < "3.13"
 huggingface-hub==0.19.4 ; python_version >= "3.9" and python_version < "3.13"
-idna==3.6 ; python_version >= "3.9" and python_version < "3.13"
+idna==3.7 ; python_version >= "3.9" and python_version < "3.13"
 loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
 numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13"
 opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
@@ -30,15 +30,15 @@ packaging==24.0 ; python_version >= "3.9" and python_version < "3.13"
 pillow==10.3.0 ; python_version >= "3.9" and python_version < "3.13"
 protobuf==4.25.3 ; python_version >= "3.9" and python_version < "3.13"
 pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
-regex==2023.12.25 ; python_version >= "3.9" and python_version < "3.13"
+regex==2024.4.16 ; python_version >= "3.9" and python_version < "3.13"
 requests==2.31.0 ; python_version >= "3.9" and python_version < "3.13"
-safetensors==0.4.2 ; python_version >= "3.9" and python_version < "3.13"
+safetensors==0.4.3 ; python_version >= "3.9" and python_version < "3.13"
 scipy==1.13.0 ; python_version >= "3.9" and python_version < "3.13"
 sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
-setuptools==69.2.0 ; python_version >= "3.9" and python_version < "3.13"
+setuptools==69.5.1 ; python_version >= "3.9" and python_version < "3.13"
-tokenizers==0.15.2 ; python_version >= "3.9" and python_version < "3.13"
+tokenizers==0.19.1 ; python_version >= "3.9" and python_version < "3.13"
 tqdm==4.66.2 ; python_version >= "3.9" and python_version < "3.13"
-transformers==4.39.3 ; python_version >= "3.9" and python_version < "3.13"
+transformers==4.40.0 ; python_version >= "3.9" and python_version < "3.13"
 typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
 typing-extensions==4.11.0 ; python_version >= "3.9" and python_version < "3.13"
 urllib3==2.2.1 ; python_version >= "3.9" and python_version < "3.13"
@@ -5,7 +5,7 @@ click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
 colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
 deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
 einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
-filelock==3.13.3 ; python_version >= "3.9" and python_version < "3.13"
+filelock==3.13.4 ; python_version >= "3.9" and python_version < "3.13"
 fsspec==2024.2.0 ; python_version >= "3.9" and python_version < "3.13"
 googleapis-common-protos==1.63.0 ; python_version >= "3.9" and python_version < "3.13"
 grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
@@ -14,7 +14,7 @@ grpcio-status==1.62.1 ; python_version >= "3.9" and python_version < "3.13"
 grpcio==1.62.1 ; python_version >= "3.9" and python_version < "3.13"
 hf-transfer==0.1.6 ; python_version >= "3.9" and python_version < "3.13"
 huggingface-hub==0.19.4 ; python_version >= "3.9" and python_version < "3.13"
-idna==3.6 ; python_version >= "3.9" and python_version < "3.13"
+idna==3.7 ; python_version >= "3.9" and python_version < "3.13"
 loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
 numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13"
 opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
@@ -30,15 +30,15 @@ packaging==24.0 ; python_version >= "3.9" and python_version < "3.13"
 pillow==10.3.0 ; python_version >= "3.9" and python_version < "3.13"
 protobuf==4.25.3 ; python_version >= "3.9" and python_version < "3.13"
 pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
-regex==2023.12.25 ; python_version >= "3.9" and python_version < "3.13"
+regex==2024.4.16 ; python_version >= "3.9" and python_version < "3.13"
 requests==2.31.0 ; python_version >= "3.9" and python_version < "3.13"
-safetensors==0.4.2 ; python_version >= "3.9" and python_version < "3.13"
+safetensors==0.4.3 ; python_version >= "3.9" and python_version < "3.13"
 scipy==1.13.0 ; python_version >= "3.9" and python_version < "3.13"
 sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
-setuptools==69.2.0 ; python_version >= "3.9" and python_version < "3.13"
+setuptools==69.5.1 ; python_version >= "3.9" and python_version < "3.13"
-tokenizers==0.15.2 ; python_version >= "3.9" and python_version < "3.13"
+tokenizers==0.19.1 ; python_version >= "3.9" and python_version < "3.13"
 tqdm==4.66.2 ; python_version >= "3.9" and python_version < "3.13"
-transformers==4.39.3 ; python_version >= "3.9" and python_version < "3.13"
+transformers==4.40.0 ; python_version >= "3.9" and python_version < "3.13"
 typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
 typing-extensions==4.11.0 ; python_version >= "3.9" and python_version < "3.13"
 urllib3==2.2.1 ; python_version >= "3.9" and python_version < "3.13"
@@ -68,6 +68,7 @@ try:
    )
    from text_generation_server.models.idefics import IDEFICSSharded
    from text_generation_server.models.llava_next import LlavaNext
+   from text_generation_server.models.idefics2 import Idefics2
    from text_generation_server.models.flash_mistral import FlashMistral
    from text_generation_server.models.flash_mixtral import FlashMixtral
    from text_generation_server.models.flash_phi import FlashPhi
@@ -327,7 +328,7 @@ def get_model(
            trust_remote_code=trust_remote_code,
        )

-   elif model_type == "llama" or model_type == "baichuan":
+   elif model_type == "llama" or model_type == "baichuan" or model_type == "phi3":
        if FLASH_ATTENTION:
            return FlashLlama(
                model_id,
@@ -579,6 +580,18 @@ def get_model(
            )
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
+   if model_type == "idefics2":
+       if FLASH_ATTENTION:
+           return Idefics2(
+               model_id,
+               revision,
+               quantize=quantize,
+               use_medusa=use_medusa,
+               dtype=dtype,
+               trust_remote_code=trust_remote_code,
+           )
+       else:
+           raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))

    if model_type == "llava_next":
        if FLASH_ATTENTION:
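`get_model` dispatches on the config's `model_type` and only serves the multimodal wrappers when flash attention kernels are available. A toy sketch of that dispatch shape (class names only; the real function builds instances and takes many more options):

    def pick_model_class(model_type: str, flash_attention: bool) -> str:
        """Toy dispatch mirroring get_model in this commit."""
        if model_type in ("llama", "baichuan", "phi3"):
            return "FlashLlama" if flash_attention else "CausalLM"
        if model_type == "idefics2":
            if not flash_attention:
                raise NotImplementedError("Idefics2 is only supported with flash attention")
            return "Idefics2"
        if model_type == "llava_next":
            if not flash_attention:
                raise NotImplementedError("LlavaNext is only supported with flash attention")
            return "LlavaNext"
        raise ValueError(f"Unsupported model_type {model_type}")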
@@ -45,58 +45,6 @@ if IS_ROCM_SYSTEM:
        raise ImportError(f"Could not load `vllm._custom_C`. Full error: {e}")


-class LlamaConfig(PretrainedConfig):
-    def __init__(
-        self,
-        vocab_size=32000,
-        hidden_size=4096,
-        intermediate_size=11008,
-        num_hidden_layers=32,
-        num_attention_heads=32,
-        num_key_value_heads=None,
-        hidden_act="silu",
-        max_position_embeddings=2048,
-        initializer_range=0.02,
-        rms_norm_eps=1e-6,
-        use_cache=True,
-        pad_token_id=0,
-        bos_token_id=1,
-        eos_token_id=2,
-        pretraining_tp=1,
-        tie_word_embeddings=False,
-        rope_scaling=None,
-        rope_theta=10000.0,
-        **kwargs,
-    ):
-        self.vocab_size = vocab_size
-        self.max_position_embeddings = max_position_embeddings
-        self.hidden_size = hidden_size
-        self.intermediate_size = intermediate_size
-        self.num_hidden_layers = num_hidden_layers
-        self.num_attention_heads = num_attention_heads
-
-        # for backward compatibility
-        if num_key_value_heads is None:
-            num_key_value_heads = num_attention_heads
-
-        self.num_key_value_heads = num_key_value_heads
-        self.hidden_act = hidden_act
-        self.initializer_range = initializer_range
-        self.rms_norm_eps = rms_norm_eps
-        self.pretraining_tp = pretraining_tp
-        self.use_cache = use_cache
-        self.rope_scaling = rope_scaling
-        self.rope_theta = rope_theta
-
-        super().__init__(
-            pad_token_id=pad_token_id,
-            bos_token_id=bos_token_id,
-            eos_token_id=eos_token_id,
-            tie_word_embeddings=tie_word_embeddings,
-            **kwargs,
-        )
-
-
 def load_attention(config, prefix, weights):
     if config.num_attention_heads != config.num_key_value_heads:
         return _load_gqa(config, prefix, weights)
@@ -108,6 +56,13 @@ def load_attention(config, prefix, weights):
                weights=weights,
                bias=False,
            )
+       elif config.model_type == "phi3":
+           return TensorParallelColumnLinear.load_qkv(
+               config,
+               prefix=f"{prefix}.qkv_proj",
+               weights=weights,
+               bias=False,
+           )
        else:
            return TensorParallelColumnLinear.load_multi(
                config,
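Phi-3 checkpoints store the attention projections as a single fused `qkv_proj` tensor rather than separate `q_proj`/`k_proj`/`v_proj` weights, so the loader reads one prefix and the model later splits the output. A small sketch of why the fused layout is equivalent, assuming equal numbers of query and key/value heads and hypothetical Phi-3-mini-like sizes:

    import torch

    hidden, n_heads, head_dim = 3072, 32, 96          # illustrative sizes, not read from any config
    qkv_proj = torch.nn.Linear(hidden, 3 * n_heads * head_dim, bias=False)

    x = torch.randn(2, 5, hidden)                     # (batch, seq, hidden)
    qkv = qkv_proj(x)                                 # one matmul instead of three
    q, k, v = qkv.split(n_heads * head_dim, dim=-1)   # same result as separate q/k/v projections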
@@ -265,13 +220,21 @@ class LlamaMLP(nn.Module):
            )
        )
        # Fuse gate and up proj
-       self.gate_up_proj = TensorParallelColumnLinear.load_multi(
-           config,
-           prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
-           weights=weights,
-           dim=0,
-           bias=False,
-       )
+       if config.model_type == "phi3":
+           self.gate_up_proj = TensorParallelColumnLinear.load_gate_up(
+               config,
+               prefix=f"{prefix}.gate_up_proj",
+               weights=weights,
+               bias=False,
+           )
+       else:
+           self.gate_up_proj = TensorParallelColumnLinear.load_multi(
+               config,
+               prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
+               weights=weights,
+               dim=0,
+               bias=False,
+           )
        self.down_proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.down_proj",
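Either branch leaves the MLP with one fused `gate_up_proj`; the SwiGLU forward pass then splits its output in half, gates one half with the activation, and multiplies. A sketch under assumed LLaMA-7B-like sizes:

    import torch
    import torch.nn.functional as F

    hidden, intermediate = 4096, 11008                 # illustrative sizes
    gate_up_proj = torch.nn.Linear(hidden, 2 * intermediate, bias=False)
    down_proj = torch.nn.Linear(intermediate, hidden, bias=False)

    x = torch.randn(2, 7, hidden)
    gate, up = gate_up_proj(x).chunk(2, dim=-1)        # fused projection, then split
    y = down_proj(F.silu(gate) * up)                   # SwiGLU: silu(gate) * up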
@@ -409,23 +409,29 @@ class MistralModel(torch.nn.Module):


 class FlashMistralForCausalLM(torch.nn.Module):
-    def __init__(self, prefix, config, weights):
+    def __init__(self, prefix, config, weights, name=None):
+        if name is None:
+            name = "model"
         super().__init__()

         self.embed_tokens = TensorParallelEmbedding(
             prefix=(
-                "model.embed_tokens" if not prefix else f"{prefix}.model.embed_tokens"
+                f"{name}.embed_tokens"
+                if not prefix
+                else f"{prefix}.{name}.embed_tokens"
             ),
             weights=weights,
         )
         self.model = MistralModel(
-            prefix="model" if not prefix else f"{prefix}.model",
+            prefix=name if not prefix else f"{prefix}.{name}",
             config=config,
             weights=weights,
         )
         self.lm_head = SpeculativeHead.load(
             config,
-            prefix="lm_head" if not prefix else f"{prefix}.lm_head",
+            # TODO dirty hack for idefics2.
+            prefix=(
+                "lm_head" if not prefix or name != "model" else f"{prefix}.lm_head"
+            ),
             weights=weights,
         )
         self.max_past = config.sliding_window
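The new `name` argument lets the same Mistral implementation load its weights from a different namespace, which is what Idefics2 needs when it reuses Mistral as its text tower under `text_model.*`. A tiny sketch of the prefix composition used above:

    def weight_prefix(prefix: str, name: str = "model") -> str:
        """Compose the weight namespace the way FlashMistralForCausalLM now does."""
        return name if not prefix else f"{prefix}.{name}"

    # standalone Mistral: tensors like "model.embed_tokens.weight"
    assert weight_prefix("", "model") == "model"
    # Mistral as the Idefics2 text tower: tensors like "model.text_model.embed_tokens.weight"
    assert weight_prefix("model", "text_model") == "model.text_model"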
server/text_generation_server/models/custom_modeling/idefics2.py (new file, 829 lines)
@@ -0,0 +1,829 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
""" PyTorch Idefics2 model."""
|
||||||
|
|
||||||
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.utils.checkpoint
|
||||||
|
from torch import nn
|
||||||
|
import math
|
||||||
|
|
||||||
|
from transformers.activations import ACT2FN
|
||||||
|
from transformers.image_processing_utils import select_best_resolution
|
||||||
|
from text_generation_server.models.custom_modeling.vlm import (
|
||||||
|
load_text_model,
|
||||||
|
load_vision_model,
|
||||||
|
)
|
||||||
|
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
|
||||||
|
|
||||||
|
from text_generation_server.utils.layers import (
|
||||||
|
TensorParallelColumnLinear,
|
||||||
|
TensorParallelEmbedding,
|
||||||
|
TensorParallelRowLinear,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
|
||||||
|
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
|
||||||
|
"""
|
||||||
|
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
||||||
|
if n_rep == 1:
|
||||||
|
return hidden_states
|
||||||
|
hidden_states = hidden_states[:, :, None, :, :].expand(
|
||||||
|
batch, num_key_value_heads, n_rep, slen, head_dim
|
||||||
|
)
|
||||||
|
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
||||||
|
|
||||||
|
|
||||||
|
class Idefics2VisionEmbeddings(nn.Module):
|
||||||
|
"""
|
||||||
|
This is a modified version of `siglip.modelign_siglip.SiglipVisionEmbeddings` to enable images of variable
|
||||||
|
resolution.
|
||||||
|
|
||||||
|
The modifications are adapted from [Patch n' Pack: NaViT, a Vision Transformer for any Aspect Ratio and Resolution](https://arxiv.org/abs/2307.06304)
|
||||||
|
which allows treating images in their native aspect ratio and without the need to resize them to the same
|
||||||
|
fixed size. In particular, we start from the original pre-trained SigLIP model
|
||||||
|
(which uses images of fixed-size square images) and adapt it by training on images of variable resolutions.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, prefix, config, weights):
|
||||||
|
super().__init__()
|
||||||
|
self.embed_dim = config.hidden_size
|
||||||
|
self.image_size = config.image_size
|
||||||
|
self.patch_size = config.patch_size
|
||||||
|
|
||||||
|
self.patch_embedding = nn.Conv2d(
|
||||||
|
in_channels=config.num_channels,
|
||||||
|
out_channels=self.embed_dim,
|
||||||
|
kernel_size=self.patch_size,
|
||||||
|
stride=self.patch_size,
|
||||||
|
padding="valid",
|
||||||
|
)
|
||||||
|
self.patch_embedding.weight = nn.Parameter(
|
||||||
|
weights.get_tensor(f"{prefix}.patch_embedding.weight"), requires_grad=False
|
||||||
|
)
|
||||||
|
self.patch_embedding.bias = nn.Parameter(
|
||||||
|
weights.get_tensor(f"{prefix}.patch_embedding.bias"), requires_grad=False
|
||||||
|
)
|
||||||
|
|
||||||
|
self.num_patches_per_side = self.image_size // self.patch_size
|
||||||
|
self.num_patches = self.num_patches_per_side**2
|
||||||
|
self.num_positions = self.num_patches
|
||||||
|
self.position_embedding = TensorParallelEmbedding(
|
||||||
|
prefix=f"{prefix}.position_embedding", weights=weights
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.BoolTensor
|
||||||
|
) -> torch.Tensor:
|
||||||
|
batch_size, _, max_im_h, max_im_w = pixel_values.shape
|
||||||
|
|
||||||
|
patch_embeds = self.patch_embedding(pixel_values)
|
||||||
|
embeddings = patch_embeds.flatten(2).transpose(1, 2)
|
||||||
|
|
||||||
|
max_nb_patches_h, max_nb_patches_w = (
|
||||||
|
max_im_h // self.patch_size,
|
||||||
|
max_im_w // self.patch_size,
|
||||||
|
)
|
||||||
|
boundaries = torch.arange(
|
||||||
|
1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side
|
||||||
|
)
|
||||||
|
position_ids = torch.full(
|
||||||
|
size=(batch_size, max_nb_patches_h * max_nb_patches_w), fill_value=0
|
||||||
|
)
|
||||||
|
|
||||||
|
for batch_idx, p_attn_mask in enumerate(patch_attention_mask):
|
||||||
|
nb_patches_h = p_attn_mask[:, 0].sum()
|
||||||
|
nb_patches_w = p_attn_mask[0].sum()
|
||||||
|
|
||||||
|
fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h)
|
||||||
|
fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w)
|
||||||
|
|
||||||
|
bucket_coords_h = torch.bucketize(
|
||||||
|
fractional_coords_h, boundaries, right=True
|
||||||
|
)
|
||||||
|
bucket_coords_w = torch.bucketize(
|
||||||
|
fractional_coords_w, boundaries, right=True
|
||||||
|
)
|
||||||
|
|
||||||
|
pos_ids = (
|
||||||
|
bucket_coords_h[:, None] * self.num_patches_per_side + bucket_coords_w
|
||||||
|
).flatten()
|
||||||
|
position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids
|
||||||
|
|
||||||
|
position_ids = position_ids.to(self.position_embedding.weight.device)
|
||||||
|
embeddings = embeddings + self.position_embedding(position_ids)
|
||||||
|
return embeddings
|
||||||
|
|
||||||
|
|
||||||
|
class Idefics2VisionAttention(nn.Module):
|
||||||
|
def __init__(self, prefix, config, weights):
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
self.embed_dim = config.hidden_size
|
||||||
|
self.num_heads = config.num_attention_heads
|
||||||
|
self.head_size = self.embed_dim // self.num_heads
|
||||||
|
if self.head_size * self.num_heads != self.embed_dim:
|
||||||
|
raise ValueError(
|
||||||
|
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
|
||||||
|
f" {self.num_heads})."
|
||||||
|
)
|
||||||
|
self.scale = self.head_size**-0.5
|
||||||
|
self.dropout = config.attention_dropout
|
||||||
|
|
||||||
|
self.num_heads = self.num_heads // weights.process_group.size()
|
||||||
|
self.embed_dim = self.embed_dim // weights.process_group.size()
|
||||||
|
|
||||||
|
self.qkv = TensorParallelColumnLinear.load_multi(
|
||||||
|
config,
|
||||||
|
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
|
||||||
|
dim=0,
|
||||||
|
weights=weights,
|
||||||
|
bias=True,
|
||||||
|
)
|
||||||
|
self.out_proj = TensorParallelRowLinear.load(
|
||||||
|
config=config, prefix=f"{prefix}.out_proj", weights=weights, bias=True
|
||||||
|
)
|
||||||
|
self.is_causal = False
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
batch_size, q_len, _ = hidden_states.size()
|
||||||
|
|
||||||
|
qkv = self.qkv(hidden_states)
|
||||||
|
query_states, key_states, value_states = qkv.split(
|
||||||
|
[
|
||||||
|
self.head_size * self.num_heads,
|
||||||
|
self.head_size * self.num_heads,
|
||||||
|
self.head_size * self.num_heads,
|
||||||
|
],
|
||||||
|
dim=2,
|
||||||
|
)
|
||||||
|
|
||||||
|
query_states = query_states.view(
|
||||||
|
batch_size, q_len, self.num_heads, self.head_size
|
||||||
|
).transpose(1, 2)
|
||||||
|
key_states = key_states.view(
|
||||||
|
batch_size, q_len, self.num_heads, self.head_size
|
||||||
|
).transpose(1, 2)
|
||||||
|
value_states = value_states.view(
|
||||||
|
batch_size, q_len, self.num_heads, self.head_size
|
||||||
|
).transpose(1, 2)
|
||||||
|
|
||||||
|
k_v_seq_len = key_states.shape[-2]
|
||||||
|
attn_weights = (
|
||||||
|
torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale
|
||||||
|
)
|
||||||
|
|
||||||
|
if attn_weights.size() != (batch_size, self.num_heads, q_len, k_v_seq_len):
|
||||||
|
raise ValueError(
|
||||||
|
f"Attention weights should be of size {(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is"
|
||||||
|
f" {attn_weights.size()}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if attention_mask is not None:
|
||||||
|
if attention_mask.size() != (batch_size, 1, q_len, k_v_seq_len):
|
||||||
|
raise ValueError(
|
||||||
|
f"Attention mask should be of size {(batch_size, 1, q_len, k_v_seq_len)}, but is {attention_mask.size()}"
|
||||||
|
)
|
||||||
|
attn_weights = attn_weights + attention_mask
|
||||||
|
|
||||||
|
# upcast attention to fp32
|
||||||
|
attn_weights = nn.functional.softmax(
|
||||||
|
attn_weights, dim=-1, dtype=torch.float32
|
||||||
|
).to(query_states.dtype)
|
||||||
|
attn_weights = nn.functional.dropout(
|
||||||
|
attn_weights, p=self.dropout, training=self.training
|
||||||
|
)
|
||||||
|
attn_output = torch.matmul(attn_weights, value_states)
|
||||||
|
|
||||||
|
if attn_output.size() != (batch_size, self.num_heads, q_len, self.head_size):
|
||||||
|
raise ValueError(
|
||||||
|
f"`attn_output` should be of size {(batch_size, self.num_heads, q_len, self.head_size)}, but is"
|
||||||
|
f" {attn_output.size()}"
|
||||||
|
)
|
||||||
|
|
||||||
|
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||||
|
attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim)
|
||||||
|
|
||||||
|
attn_output = self.out_proj(attn_output)
|
||||||
|
|
||||||
|
return attn_output
|
||||||
|
|
||||||
|
|
||||||
|
class Idefics2VisionMLP(nn.Module):
|
||||||
|
def __init__(self, prefix, config, weights):
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
self.activation_fn = ACT2FN[config.hidden_act]
|
||||||
|
self.fc1 = TensorParallelColumnLinear.load(
|
||||||
|
prefix=f"{prefix}.fc1", config=config, weights=weights, bias=True
|
||||||
|
)
|
||||||
|
self.fc2 = TensorParallelRowLinear.load(
|
||||||
|
prefix=f"{prefix}.fc2", config=config, weights=weights, bias=True
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
|
hidden_states = self.fc1(hidden_states)
|
||||||
|
hidden_states = self.activation_fn(hidden_states)
|
||||||
|
hidden_states = self.fc2(hidden_states)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class Idefics2EncoderLayer(nn.Module):
|
||||||
|
def __init__(self, prefix, config, weights):
|
||||||
|
super().__init__()
|
||||||
|
self.embed_dim = config.hidden_size
|
||||||
|
self.self_attn = Idefics2VisionAttention(
|
||||||
|
prefix=f"{prefix}.self_attn", config=config, weights=weights
|
||||||
|
)
|
||||||
|
self.layer_norm1 = nn.LayerNorm.load(
|
||||||
|
prefix=f"{prefix}.layer_norm1", eps=config.layer_norm_eps, weights=weights
|
||||||
|
)
|
||||||
|
self.layer_norm2 = nn.LayerNorm.load(
|
||||||
|
prefix=f"{prefix}.layer_norm2", eps=config.layer_norm_eps, weights=weights
|
||||||
|
)
|
||||||
|
self.mlp = Idefics2VisionMLP(
|
||||||
|
prefix=f"{prefix}.mlp", config=config, weights=weights
|
||||||
|
)
|
||||||
|
|
||||||
|
# Copied from transformers.models.siglip.modeling_siglip.SiglipEncoderLayer.forward
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
attention_mask: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
residual = hidden_states
|
||||||
|
|
||||||
|
hidden_states = self.layer_norm1(hidden_states)
|
||||||
|
hidden_states = self.self_attn(
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
)
|
||||||
|
hidden_states = residual + hidden_states
|
||||||
|
|
||||||
|
residual = hidden_states
|
||||||
|
hidden_states = self.layer_norm2(hidden_states)
|
||||||
|
hidden_states = self.mlp(hidden_states)
|
||||||
|
hidden_states = residual + hidden_states
|
||||||
|
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class Idefics2Encoder(nn.Module):
|
||||||
|
def __init__(self, prefix, config, weights):
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
self.layers = nn.ModuleList(
|
||||||
|
[
|
||||||
|
Idefics2EncoderLayer(
|
||||||
|
prefix=f"{prefix}.layers.{i}", config=config, weights=weights
|
||||||
|
)
|
||||||
|
for i in range(config.num_hidden_layers)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Ignore copy
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
inputs_embeds,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
):
|
||||||
|
hidden_states = inputs_embeds
|
||||||
|
for encoder_layer in self.layers:
|
||||||
|
hidden_states = encoder_layer(
|
||||||
|
hidden_states,
|
||||||
|
attention_mask,
|
||||||
|
)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class Idefics2VisionTransformer(nn.Module):
|
||||||
|
def __init__(self, prefix, config, weights):
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
self.embeddings = Idefics2VisionEmbeddings(
|
||||||
|
prefix=f"{prefix}.embeddings", config=config, weights=weights
|
||||||
|
)
|
||||||
|
self.encoder = Idefics2Encoder(
|
||||||
|
prefix=f"{prefix}.encoder", config=config, weights=weights
|
||||||
|
)
|
||||||
|
self.post_layernorm = nn.LayerNorm.load(
|
||||||
|
prefix=f"{prefix}.post_layernorm",
|
||||||
|
weights=weights,
|
||||||
|
eps=config.layer_norm_eps,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
pixel_values,
|
||||||
|
patch_attention_mask: Optional[torch.BoolTensor] = None,
|
||||||
|
):
|
||||||
|
batch_size = pixel_values.size(0)
|
||||||
|
if patch_attention_mask is None:
|
||||||
|
patch_size = self.config.patch_size
|
||||||
|
patch_attention_mask = torch.ones(
|
||||||
|
(
|
||||||
|
batch_size,
|
||||||
|
pixel_values.size(2) // patch_size,
|
||||||
|
pixel_values.size(3) // patch_size,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
patch_attention_mask = patch_attention_mask.to(
|
||||||
|
dtype=torch.bool, device=pixel_values.device
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = self.embeddings(
|
||||||
|
pixel_values=pixel_values, patch_attention_mask=patch_attention_mask
|
||||||
|
)
|
||||||
|
|
||||||
|
patch_attention_mask = patch_attention_mask.view(batch_size, -1)
|
||||||
|
# The call to `_upad_input` in `_flash_attention_forward` is expensive
|
||||||
|
# So when the `patch_attention_mask` is full of 1s (i.e. attending to the whole sequence),
|
||||||
|
# avoiding passing the attention_mask, which is equivalent to attending to the full sequence
|
||||||
|
if not torch.any(~patch_attention_mask):
|
||||||
|
patch_attention_mask = None
|
||||||
|
else:
|
||||||
|
patch_attention_mask = _prepare_4d_attention_mask(
|
||||||
|
patch_attention_mask, hidden_states.dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
encoder_outputs = self.encoder(
|
||||||
|
inputs_embeds=hidden_states,
|
||||||
|
attention_mask=patch_attention_mask,
|
||||||
|
)
|
||||||
|
|
||||||
|
last_hidden_state = encoder_outputs
|
||||||
|
last_hidden_state = self.post_layernorm(last_hidden_state)
|
||||||
|
|
||||||
|
return last_hidden_state
|
||||||
|
|
||||||
|
|
||||||
|
class Idefics2MLP(nn.Module):
|
||||||
|
def __init__(self, prefix, config, weights):
|
||||||
|
super().__init__()
|
||||||
|
act = config.text_config.hidden_act
|
||||||
|
self.act = (
|
||||||
|
ACT2FN[act]
|
||||||
|
if "gelu" not in act
|
||||||
|
else lambda x: torch.nn.functional.gelu(
|
||||||
|
x,
|
||||||
|
approximate=(
|
||||||
|
"tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.gate_up_proj = TensorParallelColumnLinear.load_multi(
|
||||||
|
config,
|
||||||
|
prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
|
||||||
|
weights=weights,
|
||||||
|
dim=0,
|
||||||
|
bias=False,
|
||||||
|
)
|
||||||
|
self.down_proj = TensorParallelRowLinear.load(
|
||||||
|
config,
|
||||||
|
prefix=f"{prefix}.down_proj",
|
||||||
|
weights=weights,
|
||||||
|
bias=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, hidden_states):
|
||||||
|
start_shape = hidden_states.shape[:-1]
|
||||||
|
gate_up_states = self.gate_up_proj(hidden_states)
|
||||||
|
intermediate_size = gate_up_states.shape[-1] // 2
|
||||||
|
gate_up_states = gate_up_states.view(-1, 2, intermediate_size)
|
||||||
|
return self.down_proj(
|
||||||
|
self.act(gate_up_states[:, 0]) * gate_up_states[:, 1]
|
||||||
|
).view(*start_shape, -1)
|
||||||
|
|
||||||
|
|
||||||
|
class Idefics2RMSNorm(nn.Module):
|
||||||
|
def __init__(self, prefix, weights, eps):
|
||||||
|
"""
|
||||||
|
Idefics2RMSNorm is equivalent to T5LayerNorm
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
self.weight = nn.Parameter(
|
||||||
|
weights.get_tensor(f"{prefix}.weight"), requires_grad=False
|
||||||
|
)
|
||||||
|
self.variance_epsilon = eps
|
||||||
|
|
||||||
|
def forward(self, hidden_states):
|
||||||
|
input_dtype = hidden_states.dtype
|
||||||
|
hidden_states = hidden_states.to(torch.float32)
|
||||||
|
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
||||||
|
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
||||||
|
return self.weight * hidden_states.to(input_dtype)
|
||||||
|
|
||||||
|
|
||||||
|
class Idefics2PerceiverAttention(nn.Module):
|
||||||
|
def __init__(self, prefix, config, weights):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.layer_idx = None
|
||||||
|
self.hidden_size = config.text_config.hidden_size
|
||||||
|
self.num_heads = config.perceiver_config.resampler_n_heads
|
||||||
|
self.head_size = config.perceiver_config.resampler_head_dim
|
||||||
|
self.num_key_value_heads = config.perceiver_config.num_key_value_heads
|
||||||
|
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
|
||||||
|
self.attention_dropout = config.perceiver_config.attention_dropout
|
||||||
|
self.num_heads = self.num_heads // weights.process_group.size()
|
||||||
|
self.num_key_value_heads = (
|
||||||
|
self.num_key_value_heads // weights.process_group.size()
|
||||||
|
)
|
||||||
|
|
||||||
|
self.q_proj = TensorParallelColumnLinear.load(
|
||||||
|
config,
|
||||||
|
prefix=f"{prefix}.q_proj",
|
||||||
|
weights=weights,
|
||||||
|
bias=False,
|
||||||
|
)
|
||||||
|
self.kv = TensorParallelColumnLinear.load_multi(
|
||||||
|
config,
|
||||||
|
prefixes=[f"{prefix}.k_proj", f"{prefix}.v_proj"],
|
||||||
|
dim=0,
|
||||||
|
weights=weights,
|
||||||
|
bias=False,
|
||||||
|
)
|
||||||
|
self.o_proj = TensorParallelRowLinear.load(
|
||||||
|
config=config, prefix=f"{prefix}.o_proj", weights=weights, bias=False
|
||||||
|
)
|
||||||
|
|
||||||
|
self.is_causal = False
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
latents: torch.Tensor,
|
||||||
|
context: torch.Tensor,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
|
bsz, q_len, _ = latents.size()
|
||||||
|
kv_seq_len = q_len + context.size()[1]
|
||||||
|
|
||||||
|
hidden_states = torch.concat([context, latents], dim=-2)
|
||||||
|
query_states = self.q_proj(latents)
|
||||||
|
kv = self.kv(hidden_states)
|
||||||
|
key_states, value_states = kv.split(
|
||||||
|
[
|
||||||
|
self.head_size * self.num_key_value_heads,
|
||||||
|
self.head_size * self.num_key_value_heads,
|
||||||
|
],
|
||||||
|
dim=2,
|
||||||
|
)
|
||||||
|
|
||||||
|
query_states = query_states.view(
|
||||||
|
bsz, q_len, self.num_heads, self.head_size
|
||||||
|
).transpose(1, 2)
|
||||||
|
key_states = key_states.view(
|
||||||
|
bsz, kv_seq_len, self.num_key_value_heads, self.head_size
|
||||||
|
).transpose(1, 2)
|
||||||
|
value_states = value_states.view(
|
||||||
|
bsz, kv_seq_len, self.num_key_value_heads, self.head_size
|
||||||
|
).transpose(1, 2)
|
||||||
|
|
||||||
|
# repeat k/v heads if n_kv_heads < n_heads
|
||||||
|
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||||
|
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||||
|
|
||||||
|
attn_weights = torch.matmul(
|
||||||
|
query_states, key_states.transpose(2, 3)
|
||||||
|
) / math.sqrt(self.head_size)
|
||||||
|
|
||||||
|
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
|
||||||
|
raise ValueError(
|
||||||
|
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
|
||||||
|
f" {attn_weights.size()}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if attention_mask is not None:
|
||||||
|
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
|
||||||
|
raise ValueError(
|
||||||
|
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
|
||||||
|
)
|
||||||
|
|
||||||
|
attn_weights = attn_weights + attention_mask
|
||||||
|
|
||||||
|
# upcast attention to fp32
|
||||||
|
attn_weights = nn.functional.softmax(
|
||||||
|
attn_weights, dim=-1, dtype=torch.float32
|
||||||
|
).to(query_states.dtype)
|
||||||
|
attn_output = torch.matmul(attn_weights, value_states)
|
||||||
|
|
||||||
|
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_size):
|
||||||
|
raise ValueError(
|
||||||
|
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_size)}, but is"
|
||||||
|
f" {attn_output.size()}"
|
||||||
|
)
|
||||||
|
|
||||||
|
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||||
|
attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_size)
|
||||||
|
|
||||||
|
attn_output = self.o_proj(attn_output)
|
||||||
|
|
||||||
|
return attn_output
|
||||||
|
|
||||||
|
|
||||||
|
class Idefics2PerceiverLayer(nn.Module):
|
||||||
|
def __init__(self, prefix, config, weights):
|
||||||
|
super().__init__()
|
||||||
|
self.hidden_size = config.text_config.hidden_size
|
||||||
|
self.n_latents = config.perceiver_config.resampler_n_latents
|
||||||
|
self.depth = config.perceiver_config.resampler_depth
|
||||||
|
self.rms_norm_eps = config.text_config.rms_norm_eps
|
||||||
|
|
||||||
|
self.input_latents_norm = Idefics2RMSNorm(
|
||||||
|
prefix=f"{prefix}.input_latents_norm",
|
||||||
|
weights=weights,
|
||||||
|
eps=self.rms_norm_eps,
|
||||||
|
)
|
||||||
|
self.input_context_norm = Idefics2RMSNorm(
|
||||||
|
prefix=f"{prefix}.input_context_norm",
|
||||||
|
weights=weights,
|
||||||
|
eps=self.rms_norm_eps,
|
||||||
|
)
|
||||||
|
self.self_attn = Idefics2PerceiverAttention(
|
||||||
|
prefix=f"{prefix}.self_attn", config=config, weights=weights
|
||||||
|
)
|
||||||
|
self.post_attention_layernorm = Idefics2RMSNorm(
|
||||||
|
prefix=f"{prefix}.post_attention_layernorm",
|
||||||
|
weights=weights,
|
||||||
|
eps=self.rms_norm_eps,
|
||||||
|
)
|
||||||
|
self.mlp = Idefics2MLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
latents: torch.Tensor,
|
||||||
|
context: torch.Tensor,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
latents (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
||||||
|
context (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
||||||
|
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
|
||||||
|
`(batch, sequence_length)` where padding elements are indicated by 0.
|
||||||
|
"""
|
||||||
|
residual = latents
|
||||||
|
|
||||||
|
latents = self.input_latents_norm(latents)
|
||||||
|
context = self.input_context_norm(context)
|
||||||
|
|
||||||
|
latents = self.self_attn(
|
||||||
|
latents=latents,
|
||||||
|
context=context,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
)
|
||||||
|
latents = residual + latents
|
||||||
|
residual = latents
|
||||||
|
|
||||||
|
latents = self.post_attention_layernorm(latents)
|
||||||
|
latents = self.mlp(latents)
|
||||||
|
latents = residual + latents
|
||||||
|
|
||||||
|
return latents
|
||||||
|
|
||||||
|
|
||||||
|
class Idefics2PerceiverResampler(nn.Module):
|
||||||
|
def __init__(self, prefix, config, weights) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.hidden_size = config.text_config.hidden_size
|
||||||
|
self.hidden_act = config.perceiver_config.hidden_act
|
||||||
|
self.n_latents = config.perceiver_config.resampler_n_latents
|
||||||
|
self.depth = config.perceiver_config.resampler_depth
|
||||||
|
self.rms_norm_eps = config.text_config.rms_norm_eps
|
||||||
|
|
||||||
|
# Create Latents for Perceiver
|
||||||
|
self.latents = weights.get_tensor(f"{prefix}.latents")
|
||||||
|
|
||||||
|
# Create Transformer Blocks
|
||||||
|
self.layers = nn.ModuleList(
|
||||||
|
[
|
||||||
|
Idefics2PerceiverLayer(
|
||||||
|
prefix=f"{prefix}.layers.{idx}", config=config, weights=weights
|
||||||
|
)
|
||||||
|
for idx in range(self.depth)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
self.norm = Idefics2RMSNorm(
|
||||||
|
prefix=f"{prefix}.norm",
|
||||||
|
weights=weights,
|
||||||
|
eps=config.text_config.rms_norm_eps,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
context: torch.Tensor,
|
||||||
|
attention_mask,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
# seq embed -> bsz seq embed
|
||||||
|
latents = self.latents.unsqueeze(0).expand(
|
||||||
|
(context.shape[0], *self.latents.size())
|
||||||
|
)
|
||||||
|
|
||||||
|
latent_attention_mask = torch.ones(
|
||||||
|
(attention_mask.size(0), latents.size(1)),
|
||||||
|
dtype=attention_mask.dtype,
|
||||||
|
device=attention_mask.device,
|
||||||
|
)
|
||||||
|
attention_mask = torch.cat([attention_mask, latent_attention_mask], dim=-1)
|
||||||
|
attention_mask = _prepare_4d_attention_mask(
|
||||||
|
attention_mask, latents.dtype, tgt_len=self.n_latents
|
||||||
|
)
|
||||||
|
|
||||||
|
compressed_context = latents
|
||||||
|
for perceiver_layer in self.layers:
|
||||||
|
compressed_context = perceiver_layer(
|
||||||
|
compressed_context,
|
||||||
|
context,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
)
|
||||||
|
compressed_context = self.norm(compressed_context)
|
||||||
|
|
||||||
|
return compressed_context
|
||||||
|
|
||||||
|
|
||||||
|
class Idefics2Connector(nn.Module):
|
||||||
|
def __init__(self, prefix, config, weights):
|
||||||
|
super().__init__()
|
||||||
|
self.modality_projection = Idefics2MLP(
|
||||||
|
prefix=f"{prefix}.modality_projection", config=config, weights=weights
|
||||||
|
)
|
||||||
|
self.perceiver_resampler = Idefics2PerceiverResampler(
|
||||||
|
prefix=f"{prefix}.perceiver_resampler", config=config, weights=weights
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, image_hidden_states, attention_mask):
|
||||||
|
image_hidden_states = self.modality_projection(image_hidden_states)
|
||||||
|
image_hidden_states = self.perceiver_resampler(
|
||||||
|
context=image_hidden_states, attention_mask=attention_mask
|
||||||
|
)
|
||||||
|
return image_hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class Idefics2ForConditionalGeneration(nn.Module):
|
||||||
|
def __init__(self, prefix, config, weights):
|
||||||
|
super().__init__()
|
||||||
|
config.vision_config.quantize = config.quantize
|
||||||
|
config.vision_config.use_medusa = config.use_medusa
|
||||||
|
config.text_config.quantize = config.quantize
|
||||||
|
config.text_config.use_medusa = config.use_medusa
|
||||||
|
|
||||||
|
vision_config = config.vision_config
|
||||||
|
self.text_model = load_text_model(
|
||||||
|
prefix="model" if not prefix else f"{prefix}.model",
|
||||||
|
config=config.text_config,
|
||||||
|
weights=weights,
|
||||||
|
name="text_model",
|
||||||
|
)
|
||||||
|
self.dtype = weights.dtype
|
||||||
|
self.vision_model = Idefics2VisionTransformer(
|
||||||
|
prefix=f"{prefix}.model.vision_model" if prefix else "model.vision_model",
|
||||||
|
config=vision_config,
|
||||||
|
weights=weights,
|
||||||
|
)
|
||||||
|
self.connector = Idefics2Connector(
|
||||||
|
prefix=f"{prefix}.model.connector" if prefix else "model.connector",
|
||||||
|
config=config,
|
||||||
|
weights=weights,
|
||||||
|
)
|
||||||
|
self.config = config
|
||||||
|
self.image_seq_len = config.perceiver_config.resampler_n_latents
|
||||||
|
self.image_token_id = config.image_token_id
|
||||||
|
self.pad_token_id = (
|
||||||
|
config.pad_token_id if config.pad_token_id is not None else -1
|
||||||
|
)
|
||||||
|
|
||||||
|
def _merge_input_ids_with_image_features(
|
||||||
|
self,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
inputs_embeds: torch.Tensor,
|
||||||
|
image_features: torch.Tensor,
|
||||||
|
):
|
||||||
|
"""In place merges in vision_embeddings with inputs_embeds."""
|
||||||
|
# mask = input_ids == self.config.image_token_index
|
||||||
|
mask = input_ids == self.config.image_token_id
|
||||||
|
# Let's pray we have enabled enough slots !
|
||||||
|
inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
|
||||||
|
return inputs_embeds
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
position_ids: torch.Tensor,
|
||||||
|
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||||
|
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||||
|
block_tables: torch.Tensor,
|
||||||
|
slots: torch.Tensor,
|
||||||
|
input_lengths: torch.Tensor,
|
||||||
|
max_s: int,
|
||||||
|
prefill_cache_indices: Optional[torch.Tensor],
|
||||||
|
lm_head_indices: Optional[torch.Tensor] = None,
|
||||||
|
pixel_values: torch.FloatTensor = None,
|
||||||
|
pixel_attention_mask: Optional[torch.BoolTensor] = None,
|
||||||
|
# Unused here
|
||||||
|
image_sizes: Optional[torch.Tensor] = None,
|
||||||
|
):
|
||||||
|
inputs_embeds = self.text_model.embed_tokens(input_ids)
|
||||||
|
if pixel_values is not None:
|
||||||
|
batch_size, num_images, num_channels, height, width = pixel_values.shape
|
||||||
|
all_states = []
|
||||||
|
all_pixel_values = pixel_values
|
||||||
|
all_pixel_mask = pixel_attention_mask
|
||||||
|
for i in range(batch_size):
|
||||||
|
pixel_values = all_pixel_values.to(
|
||||||
|
dtype=self.dtype
|
||||||
|
) # fp16 compatibility
|
||||||
|
pixel_values = pixel_values[i : i + 1]
|
||||||
|
pixel_values = pixel_values.view(num_images, *pixel_values.shape[2:])
|
||||||
|
|
||||||
|
# Remove padding images - padding images are full 0.
|
||||||
|
nb_values_per_image = pixel_values.shape[1:].numel()
|
||||||
|
real_images_inds = (pixel_values == 0.0).sum(
|
||||||
|
dim=(-1, -2, -3)
|
||||||
|
) != nb_values_per_image
|
||||||
|
pixel_values = pixel_values[real_images_inds].contiguous()
|
||||||
|
|
||||||
|
# Handle the vision attention mask
|
||||||
|
if pixel_attention_mask is None:
|
||||||
|
pixel_attention_mask = torch.ones(
|
||||||
|
size=(
|
||||||
|
pixel_values.size(0),
|
||||||
|
pixel_values.size(2),
|
||||||
|
pixel_values.size(3),
|
||||||
|
),
|
||||||
|
dtype=torch.bool,
|
||||||
|
device=pixel_values.device,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Remove padding images from the mask/pP p
|
||||||
|
pixel_attention_mask = all_pixel_mask[i : i + 1]
|
||||||
|
pixel_attention_mask = pixel_attention_mask.view(
|
||||||
|
1 * num_images, *pixel_attention_mask.shape[2:]
|
||||||
|
)
|
||||||
|
pixel_attention_mask = pixel_attention_mask[
|
||||||
|
real_images_inds
|
||||||
|
].contiguous()
|
||||||
|
|
||||||
|
patch_size = self.config.vision_config.patch_size
|
||||||
|
patches_subgrid = pixel_attention_mask.unfold(
|
||||||
|
dimension=1, size=patch_size, step=patch_size
|
||||||
|
)
|
||||||
|
patches_subgrid = patches_subgrid.unfold(
|
||||||
|
dimension=2, size=patch_size, step=patch_size
|
||||||
|
)
|
||||||
|
patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
|
||||||
|
|
||||||
|
# Get sequence from the vision encoder
|
||||||
|
image_hidden_states = self.vision_model(
|
||||||
|
pixel_values=pixel_values,
|
||||||
|
patch_attention_mask=patch_attention_mask,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Modality projection & resampling
|
||||||
|
image_hidden_states = self.connector(
|
||||||
|
image_hidden_states,
|
||||||
|
attention_mask=patch_attention_mask.view(pixel_values.size(0), -1),
|
||||||
|
)
|
||||||
|
all_states.append(image_hidden_states)
|
||||||
|
image_hidden_states = torch.stack(all_states, dim=0)
|
||||||
|
# When we generate, we don't want to replace the potential image_token_id that we generated by images
|
||||||
|
# that simply don't exist
|
||||||
|
inputs_embeds = self._merge_input_ids_with_image_features(
|
||||||
|
input_ids, inputs_embeds, image_hidden_states
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = self.text_model.model(
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
position_ids=position_ids,
|
||||||
|
cu_seqlen_prefill=cu_seqlen_prefill,
|
||||||
|
kv_cache=kv_cache,
|
||||||
|
block_tables=block_tables,
|
||||||
|
slots=slots,
|
||||||
|
input_lengths=input_lengths,
|
||||||
|
max_s=max_s,
|
||||||
|
true_max_s=max_s,
|
||||||
|
prefill_cache_indices=None,
|
||||||
|
)
|
||||||
|
if lm_head_indices is not None:
|
||||||
|
hidden_states = hidden_states[lm_head_indices]
|
||||||
|
logits, speculative_logits = self.text_model.lm_head(hidden_states)
|
||||||
|
return logits, speculative_logits
|
@@ -23,6 +23,10 @@ from torch import nn
 from transformers.activations import ACT2FN
 from transformers.image_processing_utils import select_best_resolution

+from text_generation_server.models.custom_modeling.vlm import (
+    load_text_model,
+    load_vision_model,
+)
 from text_generation_server.utils.layers import (
     TensorParallelColumnLinear,
     TensorParallelRowLinear,
@@ -105,36 +109,6 @@ class LlavaNextMultiModalProjector(nn.Module):
         return hidden_states


-def load_vision_model(prefix, config, weights):
-    if config.model_type == "clip_vision_model":
-        from text_generation_server.models.custom_modeling.clip import (
-            CLIPVisionTransformer,
-        )
-
-        return CLIPVisionTransformer(
-            prefix=f"{prefix}.vision_model", config=config, weights=weights
-        )
-    else:
-        raise RuntimeError(f"Unsupported model type {config.model_type}")
-
-
-def load_text_model(prefix, config, weights):
-    if config.model_type == "llama":
-        from text_generation_server.models.custom_modeling.flash_llama_modeling import (
-            FlashLlamaForCausalLM,
-        )
-
-        return FlashLlamaForCausalLM(prefix, config, weights)
-    elif config.model_type == "mistral":
-        from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
-            FlashMistralForCausalLM,
-        )
-
-        return FlashMistralForCausalLM(prefix, config, weights)
-    else:
-        raise RuntimeError(f"Unsupported model type {config.model_type}")
-
-
 class LlavaNextForConditionalGeneration(nn.Module):
     def __init__(self, prefix, config, weights):
         super().__init__()
@@ -180,7 +154,12 @@ class LlavaNextForConditionalGeneration(nn.Module):
        """In place merges in vision_embeddings with inputs_embeds."""
        mask = input_ids == self.config.image_token_index
        # Let's pray we have enabled enough slots !
-       inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
+       try:
+           inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
+       except Exception as e:
+           raise RuntimeError(
+               f"Cannot fill images right now. If error happens at warmup, make sure you have enough `--max-input-tokens`  to handle images. If error happens at regular runtime, please fill in an issue: {e}"
+           )
        return inputs_embeds

    def forward(
@ -196,6 +175,8 @@ class LlavaNextForConditionalGeneration(nn.Module):
|
|||||||
prefill_cache_indices: Optional[torch.Tensor],
|
prefill_cache_indices: Optional[torch.Tensor],
|
||||||
lm_head_indices: Optional[torch.Tensor] = None,
|
lm_head_indices: Optional[torch.Tensor] = None,
|
||||||
pixel_values: torch.FloatTensor = None,
|
pixel_values: torch.FloatTensor = None,
|
||||||
|
# Unused for this model
|
||||||
|
pixel_attention_mask=None,
|
||||||
image_sizes: Optional[torch.LongTensor] = None,
|
image_sizes: Optional[torch.LongTensor] = None,
|
||||||
):
|
):
|
||||||
inputs_embeds = self.language_model.embed_tokens(input_ids)
|
inputs_embeds = self.language_model.embed_tokens(input_ids)
|
||||||
|
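Note: the merge guarded by the try/except above is just a boolean-mask scatter; every position whose input id equals the image token id is overwritten with one row of the flattened image features, so the number of image-token slots in the prompt must match the number of feature rows. A minimal standalone sketch (the token id and tensor shapes are illustrative assumptions, not values from the diff):

import torch

image_token_index = 32000                       # hypothetical image token id
input_ids = torch.tensor([1, 32000, 32000, 2])  # two image slots in the prompt
inputs_embeds = torch.zeros(4, 8)               # (seq_len, hidden_size)
image_features = torch.randn(1, 2, 8)           # (num_images, features_per_image, hidden)

mask = input_ids == image_token_index
# view(-1, hidden) flattens all image features; the count of True positions in `mask`
# must equal the number of feature rows, which is exactly what the guard protects.
inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])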
server/text_generation_server/models/custom_modeling/vlm.py (new file, 28 lines)
@@ -0,0 +1,28 @@
+def load_text_model(prefix, config, weights, name=None):
+    if config.model_type == "llama":
+        from text_generation_server.models.custom_modeling.flash_llama_modeling import (
+            FlashLlamaForCausalLM,
+        )
+
+        return FlashLlamaForCausalLM(prefix, config, weights)
+    elif config.model_type == "mistral":
+        from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
+            FlashMistralForCausalLM,
+        )
+
+        return FlashMistralForCausalLM(prefix, config, weights, name=name)
+    else:
+        raise RuntimeError(f"Unsupported model type {config.model_type}")
+
+
+def load_vision_model(prefix, config, weights):
+    if config.model_type == "clip_vision_model":
+        from text_generation_server.models.custom_modeling.clip import (
+            CLIPVisionTransformer,
+        )
+
+        return CLIPVisionTransformer(
+            prefix=f"{prefix}.vision_model", config=config, weights=weights
+        )
+    else:
+        raise RuntimeError(f"Unsupported model type {config.model_type}")
@@ -2,14 +2,13 @@ import torch
 import torch.distributed

 from opentelemetry import trace
-from transformers import AutoConfig, AutoTokenizer
+from transformers import AutoConfig, AutoTokenizer, GenerationConfig
 from transformers.models.llama import LlamaTokenizer
 from typing import Optional

 from text_generation_server.models import FlashCausalLM
 from text_generation_server.models.custom_modeling.flash_llama_modeling import (
     FlashLlamaForCausalLM,
-    LlamaConfig,
 )
 from text_generation_server.utils import (
     initialize_torch_distributed,
@@ -53,8 +52,17 @@ class FlashLlama(FlashCausalLM):
             truncation_side="left",
             trust_remote_code=trust_remote_code,
         )
+        try:
+            generation_config = GenerationConfig.from_pretrained(
+                model_id, revision=revision, trust_remote_code=trust_remote_code
+            )
+            if isinstance(generation_config.eos_token_id, (list, set)):
+                # TODO Huge hack
+                tokenizer._eos_token_ids = set(generation_config.eos_token_id)
+        except Exception:
+            pass

-        config = LlamaConfig.from_pretrained(
+        config = AutoConfig.from_pretrained(
             model_id, revision=revision, trust_remote_code=trust_remote_code
         )
         config.quantize = quantize
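Note: the added block works around checkpoints whose generation_config.json declares several EOS tokens (a list) while tokenizer.eos_token_id can only hold a single id. A minimal sketch of the same normalization, with hand-written stand-ins for the tokenizer and GenerationConfig (the ids below are made up):

from types import SimpleNamespace

tokenizer = SimpleNamespace(eos_token_id=2)
generation_config = SimpleNamespace(eos_token_id=[2, 32000, 32001])

# Stash every declared EOS id as a set on the tokenizer, mirroring the hack above.
if isinstance(generation_config.eos_token_id, (list, set)):
    tokenizer._eos_token_ids = set(generation_config.eos_token_id)

print(getattr(tokenizer, "_eos_token_ids", tokenizer.eos_token_id))  # {2, 32000, 32001}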
@@ -511,18 +511,33 @@ class BaseFlashMistral(FlashCausalLM):
         cuda_graph = self.cuda_graphs.get(padded_bs, None)

         if cu_seqlen_prefill is not None or cuda_graph is None:
-            logits, speculative_logits = self.model.forward(
-                input_ids=input_ids,
-                position_ids=position_ids,
-                cu_seqlen_prefill=cu_seqlen_prefill,
-                kv_cache=kv_cache,
-                block_tables=block_tables,
-                slots=slots,
-                input_lengths=input_lengths,
-                max_s=max_s,
-                prefill_cache_indices=batch.prefill_cache_indices,
-                lm_head_indices=lm_head_indices,
-            )
+            if cu_seqlen_prefill is None:
+                logits, speculative_logits = self.compiled_model(
+                    input_ids=input_ids,
+                    position_ids=position_ids,
+                    cu_seqlen_prefill=cu_seqlen_prefill,
+                    kv_cache=kv_cache,
+                    block_tables=block_tables,
+                    slots=slots,
+                    input_lengths=input_lengths,
+                    max_s=max_s,
+                    prefill_cache_indices=batch.prefill_cache_indices,
+                    lm_head_indices=lm_head_indices,
+                )
+            else:
+                logits, speculative_logits = self.model.forward(
+                    input_ids=input_ids,
+                    position_ids=position_ids,
+                    cu_seqlen_prefill=cu_seqlen_prefill,
+                    kv_cache=kv_cache,
+                    block_tables=block_tables,
+                    slots=slots,
+                    input_lengths=input_lengths,
+                    max_s=max_s,
+                    prefill_cache_indices=batch.prefill_cache_indices,
+                    lm_head_indices=lm_head_indices,
+                )
             if batch.prefill_cache_indices is not None:
                 batch.prefill_cache_indices = None
             return logits, speculative_logits
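Note: the intent of the split above appears to be routing decode steps (cu_seqlen_prefill is None) through a pre-built compiled module while prefill stays on the eager forward; how compiled_model is produced is not shown in this hunk. A generic, self-contained sketch of that pattern using torch.compile, offered only as an assumption about the design rather than the project's actual implementation:

import torch

class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(16, 16)

    def forward(self, x):
        return self.proj(x)

model = TinyModel()
# Hypothetical: compile once, route fixed-shape decode calls through the compiled
# callable, and keep variable-shape prefill calls on the eager module.
compiled_model = torch.compile(model)

def forward(x, is_prefill: bool):
    if not is_prefill:
        return compiled_model(x)
    return model(x)

out = forward(torch.randn(4, 16), is_prefill=True)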
server/text_generation_server/models/idefics2.py (new file, 51 lines)
@@ -0,0 +1,51 @@
+import torch
+
+from typing import Optional, Tuple
+
+from transformers import (
+    AutoProcessor,
+)
+from text_generation_server.models.custom_modeling.idefics2 import (
+    Idefics2ForConditionalGeneration,
+)
+
+from text_generation_server.models.vlm_causal_lm import VlmCausalLM
+
+
+class Idefics2(VlmCausalLM):
+    def __init__(
+        self,
+        model_id: str,
+        revision: Optional[str] = None,
+        quantize: Optional[str] = None,
+        use_medusa: Optional[str] = None,
+        dtype: Optional[torch.dtype] = None,
+        trust_remote_code: bool = False,
+    ):
+        self.processor = AutoProcessor.from_pretrained(
+            model_id,
+            revision=revision,
+            trust_remote_code=trust_remote_code,
+            # XXX: Extremely important to cap resolution in order to limit
+            # VRAM usage.
+            size={"longest_edge": 448, "shortest_edge": 378},
+        )
+        super().__init__(
+            model_cls=Idefics2ForConditionalGeneration,
+            model_id=model_id,
+            revision=revision,
+            quantize=quantize,
+            use_medusa=use_medusa,
+            dtype=dtype,
+            trust_remote_code=trust_remote_code,
+        )
+
+    def get_layer_config(self, model) -> Tuple[int, int, int]:
+        return (
+            len(model.text_model.model.layers),
+            model.text_model.model.num_key_value_heads,
+            model.text_model.model.head_size,
+        )
+
+    def max_past(self) -> Optional[int]:
+        return getattr(self.model.text_model, "max_past", None)
@@ -1,6 +1,6 @@
 import torch

-from typing import Optional
+from typing import Optional, Tuple

 from transformers import (
     AutoProcessor,
@@ -34,3 +34,13 @@ class LlavaNext(VlmCausalLM):
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
+
+    def get_layer_config(self, model) -> Tuple[int, int, int]:
+        return (
+            len(model.language_model.model.layers),
+            model.language_model.model.num_key_value_heads,
+            model.language_model.model.head_size,
+        )
+
+    def max_past(self) -> Optional[int]:
+        return getattr(self.model.language_model, "max_past", None)
@@ -27,7 +27,14 @@ class Model(ABC):
     ):
         self.model = model.eval()
         self.tokenizer = tokenizer
+
+        # all_special_ids is not set correctly if the rust tokenizer is unpacked
+        # TODO report this to transformers.
+        other_special_ids = {
+            id for id, token in tokenizer.added_tokens_decoder.items() if token.special
+        }
         self.all_special_ids = set(tokenizer.all_special_ids)
+        self.all_special_ids.update(other_special_ids)
         self.requires_padding = requires_padding
         self.dtype = dtype
         self.device = device
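Note: the added set comprehension picks up added tokens flagged as special, since all_special_ids can miss them in the situation described by the comment. A sketch of the same set arithmetic with a stand-in for tokenizer.added_tokens_decoder (the ids and token objects below are made up):

from types import SimpleNamespace

all_special_ids = {0, 1, 2}
added_tokens_decoder = {
    32000: SimpleNamespace(content="<image>", special=True),
    32001: SimpleNamespace(content="<fake_token_around_image>", special=True),
    32002: SimpleNamespace(content="<extra_token>", special=False),
}

other_special_ids = {
    id for id, token in added_tokens_decoder.items() if token.special
}
all_special_ids = set(all_special_ids)
all_special_ids.update(other_special_ids)
print(sorted(all_special_ids))  # [0, 1, 2, 32000, 32001]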
@@ -64,6 +64,46 @@ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
     return height // patch_size, width // patch_size


+def image_text_replacement(image_input, config, image_id) -> str:
+    if config.model_type == "idefics2":
+        # TODO technically depends on image splitting which is not implemented.
+        num_features = 320
+        return (
+            "<fake_token_around_image>"
+            + "<image>" * num_features
+            + "<fake_token_around_image>"
+        )
+    elif config.model_type == "llava_next":
+        height, width = image_input["image_sizes"][image_id]
+        num_features = get_number_of_features(height, width, config)
+        from loguru import logger
+
+        logger.info(f"Found {num_features} in image of resolution {height}x{width}")
+        return "<image>" * num_features
+    else:
+        raise RuntimeError(f"Unknown config {config.model_type} for multimodal")
+
+
+def get_unpadded_features(
+    height: int, width: int, npatches: int, num_patch_height: int, num_patch_width: int
+) -> Tuple[int, int]:
+    current_height = npatches * num_patch_height
+    current_width = npatches * num_patch_width
+
+    aspect_ratio: float = width / height
+    current_aspect_ratio: float = current_width / current_height
+    if aspect_ratio > current_aspect_ratio:
+        new_height = (height * current_width) // width
+        current_height = new_height
+    else:
+        new_width = (width * current_height) // height
+        current_width = new_width
+
+    unpadded_features = current_height * current_width
+    newline_features = current_height
+    return (unpadded_features, newline_features)
+
+
 def get_number_of_features(height: int, width: int, config) -> int:
     # From config
     # Hardcoded for CLIP for now
@@ -81,12 +121,9 @@ def get_number_of_features(height: int, width: int, config) -> int:
         image_grid_pinpoints,
         image_size,
     )
-    height_of_patch = math.ceil(height / width * npatches)
-
-    unpadded_features = npatches * height_of_patch * num_patch_height * num_patch_width
-    # They are only added after width
-    newline_features = height_of_patch * num_patch_width
+    unpadded_features, newline_features = get_unpadded_features(
+        height, width, npatches, num_patch_height, num_patch_width
+    )
     # The base patch covers the entire image
     base_features = npatches**2
     return unpadded_features + newline_features + base_features
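Note: as a sanity check, the commented-out assertion removed just below (get_number_of_features(640, 640) == 2928) is reproduced by the new helper under the usual LLaVA-NeXT CLIP settings, which are assumptions here: a 336-pixel vision tower with 14-pixel patches (npatches = 24) and a best-fit 2x2 grid for a 640x640 image.

# Worked example (assumed config: npatches=24, best grid 2x2 for a 640x640 image).
height, width = 640, 640
npatches, num_patch_height, num_patch_width = 24, 2, 2

current_height = npatches * num_patch_height        # 48
current_width = npatches * num_patch_width          # 48
# Aspect ratios are equal, so only the width branch recomputes: (640 * 48) // 640 = 48.
unpadded_features = current_height * current_width  # 2304
newline_features = current_height                   # 48
base_features = npatches**2                         # 576
print(unpadded_features + newline_features + base_features)  # 2928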
@@ -99,12 +136,9 @@ def load_data_uri(image_uri: str) -> Image.Image:
     return image


-# assert get_number_of_features(889, 1024) == 2634, f"{get_number_of_features(889, 1024)}"
-# assert get_number_of_features(640, 640) == 2928
-
-
 class VlmCausalLMBatch(FlashMistralBatch):
     pixel_values: Optional[List[torch.Tensor]]
+    pixel_attention_mask: Optional[List[torch.Tensor]]
     image_sizes: Optional[List[Tuple[int, int]]]

     @classmethod
@@ -112,6 +146,7 @@ class VlmCausalLMBatch(FlashMistralBatch):
     def concatenate(cls, batches):
         batch = super(VlmCausalLMBatch, cls).concatenate(batches)
         batch.pixel_values = None
+        batch.pixel_attention_mask = None
         batch.image_sizes = None
         return batch

@@ -119,6 +154,7 @@ class VlmCausalLMBatch(FlashMistralBatch):
     def filter(self, request_ids: List[int]):
         batch = super().filter(request_ids)
         batch.pixel_values = None
+        batch.pixel_attention_mask = None
         batch.image_sizes = None
         return batch

@@ -130,6 +166,7 @@ class VlmCausalLMBatch(FlashMistralBatch):
         for r in requests:
             chunks = split(r.inputs)
             full_text = ""
+            image_id = 0
             for chunk in chunks:
                 if chunk["type"] == "text":
                     full_text += chunk["content"]
@@ -147,9 +184,7 @@ class VlmCausalLMBatch(FlashMistralBatch):
                            "Cannot process input image not starting with data:"
                        )
                    image_input = processor.image_processor(image, return_tensors="pt")
-                    height, width = image_input["image_sizes"][0]
-                    num_features = get_number_of_features(height, width, config)
-                    full_text += "<image>" * num_features
+                    full_text += image_text_replacement(image_input, config, image_id)
                    image_inputs.append(image_input)
                else:
                    raise RuntimeError(f"Invalid chunk type {chunk['type']}")
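Note: the effect of image_text_replacement is that each image chunk in a request is replaced, before tokenization, by a run of placeholder tokens sized to that image; for idefics2 the run is a fixed 320 <image> tokens wrapped in <fake_token_around_image>, for llava_next it is the per-resolution feature count. A self-contained sketch of the expansion (the chunk layout and feature count below are illustrative assumptions):

# Sketch: expanding a mixed text/image prompt into placeholder tokens.
def replacement(model_type: str, num_features: int) -> str:
    if model_type == "idefics2":
        return "<fake_token_around_image>" + "<image>" * 320 + "<fake_token_around_image>"
    elif model_type == "llava_next":
        return "<image>" * num_features
    raise RuntimeError(f"Unknown config {model_type} for multimodal")

chunks = [("text", "What is in this picture? "), ("image", 2928), ("text", " Answer briefly.")]
full_text = ""
for kind, payload in chunks:
    full_text += payload if kind == "text" else replacement("llava_next", payload)

print(len(full_text.split("<image>")) - 1)  # 2928 placeholders go to the tokenizer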
@@ -161,12 +196,21 @@ class VlmCausalLMBatch(FlashMistralBatch):
             batch_inputs, truncation=True, max_length=max_truncation
         )["input_ids"]
         if image_inputs:
-            image_inputs = {
+            image_input = image_inputs[0]
+            new_image_inputs = {
                 "pixel_values": torch.cat(
                     [img["pixel_values"] for img in image_inputs], dim=0
                 ),
-                "image_sizes": torch.cat([img["image_sizes"] for img in image_inputs]),
             }
+            if "pixel_attention_mask" in image_input:
+                new_image_inputs["pixel_attention_mask"] = torch.cat(
+                    [img["pixel_attention_mask"] for img in image_inputs], dim=0
+                )
+            if "image_sizes" in image_input:
+                new_image_inputs["image_sizes"] = torch.cat(
+                    [img["image_sizes"] for img in image_inputs], dim=0
+                )
+            image_inputs = new_image_inputs
         else:
             image_inputs = None
         return batch_tokenized_inputs, image_inputs
@@ -187,9 +231,19 @@ class VlmCausalLMBatch(FlashMistralBatch):
         batch = cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)
         if image_inputs is not None:
             batch.pixel_values = image_inputs["pixel_values"].to(device=device)
-            batch.image_sizes = image_inputs["image_sizes"].to(device=device)
+            if "pixel_attention_mask" in image_inputs:
+                batch.pixel_attention_mask = image_inputs["pixel_attention_mask"].to(
+                    device=device
+                )
+            else:
+                batch.pixel_attention_mask = None
+            if "image_sizes" in image_inputs:
+                batch.image_sizes = image_inputs["image_sizes"].to(device=device)
+            else:
+                batch.image_sizes = None
         else:
             batch.pixel_values = None
+            batch.pixel_attention_mask = None
             batch.image_sizes = None
         return batch

@@ -199,16 +253,6 @@ class VlmCausalLM(BaseFlashMistral):
     def batch_type(self) -> Type[VlmCausalLMBatch]:
         return VlmCausalLMBatch

-    def get_layer_config(self, model) -> Tuple[int, int, int]:
-        return (
-            len(model.language_model.model.layers),
-            model.language_model.model.num_key_value_heads,
-            model.language_model.model.head_size,
-        )
-
-    def max_past(self) -> Optional[int]:
-        return getattr(self.model.language_model, "max_past", None)
-
     def forward(
         self, batch: VlmCausalLMBatch
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
@@ -270,17 +314,14 @@ class VlmCausalLM(BaseFlashMistral):
             max_s = min(self.max_past(), max_s)

         bs = input_ids.shape[0]
-        padded_bs = bs
-        if bs == 3:
-            padded_bs = 4
-        elif 3 < bs <= 8:
-            padded_bs = 8
-        elif bs > 8:
-            padded_bs = (bs + 7) // 8 * 8
-
         # Try to find an associated cuda graph
-        cuda_graph = self.cuda_graphs.get(padded_bs, None)
+        bs = input_ids.shape[0]
+        sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs])
+        if sorted_padded_bs:
+            # Get associated cuda graph
+            cuda_graph = self.cuda_graphs[sorted_padded_bs[0]]
+        else:
+            cuda_graph = None
         if cu_seqlen_prefill is not None or cuda_graph is None:
             logits, speculative_logits = self.model.forward(
                 input_ids=input_ids,
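Note: the replacement above drops the hand-rolled batch-size padding rules and instead reuses whichever captured CUDA graph has the smallest batch size that still covers the current batch, falling back to the eager forward when none fits. The lookup itself is just a sorted-keys scan; a self-contained sketch (the captured sizes below are assumptions):

# Sketch: pick the smallest captured graph whose batch size covers `bs`.
cuda_graphs = {1: "graph_bs1", 2: "graph_bs2", 4: "graph_bs4", 8: "graph_bs8"}

def pick_graph(bs: int):
    sorted_padded_bs = sorted(k for k in cuda_graphs.keys() if k >= bs)
    return cuda_graphs[sorted_padded_bs[0]] if sorted_padded_bs else None

print(pick_graph(3))   # graph_bs4
print(pick_graph(16))  # None, so the eager forward path is used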
@@ -294,12 +335,15 @@ class VlmCausalLM(BaseFlashMistral):
                 prefill_cache_indices=batch.prefill_cache_indices,
                 lm_head_indices=lm_head_indices,
                 pixel_values=batch.pixel_values,
+                pixel_attention_mask=batch.pixel_attention_mask,
                 image_sizes=batch.image_sizes,
             )
             if batch.prefill_cache_indices is not None:
                 batch.prefill_cache_indices = None
             if batch.pixel_values is not None:
                 batch.pixel_values = None
+            if batch.pixel_attention_mask is not None:
+                batch.pixel_attention_mask = None
             if batch.image_sizes is not None:
                 batch.image_sizes = None
             return logits, speculative_logits
@@ -756,6 +756,19 @@ class TensorParallelHead(SuperLayer):


 class TensorParallelColumnLinear(SuperLayer):
+    @classmethod
+    def load_gate_up(cls, config, prefix: str, weights, bias: bool):
+        """Specific method when the QKV was joined after the fact"""
+        weight = weights.get_weights_col_packed_gate_up(
+            prefix, quantize=config.quantize
+        )
+        if bias:
+            raise NotImplementedError("packed_gate_up only implemented without bias")
+        else:
+            bias = None
+        linear = get_linear(weight, bias, config.quantize)
+        return cls(linear)
+
     @classmethod
     def load_qkv(cls, config, prefix: str, weights, bias: bool):
         """Specific method when the QKV was joined after the fact"""
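Note: load_gate_up mirrors load_qkv but with two packed blocks instead of three, for checkpoints that store the MLP gate and up projections as one concatenated tensor. A toy illustration of that 2-block layout feeding a SwiGLU-style MLP; the sizes and the gate-then-up ordering follow the usual LLaMA convention and are assumptions here:

import torch

hidden, intermediate = 8, 16
packed = torch.randn(2 * intermediate, hidden)    # [gate; up] stored concatenated

gate_w, up_w = packed.split(intermediate, dim=0)  # the two "blocks"
x = torch.randn(1, hidden)
swiglu = torch.nn.functional.silu(x @ gate_w.t()) * (x @ up_w.t())
print(swiglu.shape)  # torch.Size([1, 16])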
@@ -143,6 +143,8 @@ class FrequencyPenaltyLogitsProcessor(LogitsProcessor):
         score = torch.gather(scores, 1, input_ids)
         # if score < 0 then penalty has to be multiplied to reduce the previous token probability
         score = -torch.where(score < 0, score * self.penalty, score / self.penalty)
+        # set score to 0 where input_ids is a padding token
+        score *= input_ids.ne(0)

         return scores.scatter_add_(1, input_ids, score)

@@ -168,6 +170,8 @@ class HeterogeneousFrequencyPenaltyLogitsProcessor(LogitsProcessor):
         score = -torch.where(
             score < 0, score * self.penalty_tensor, score / self.penalty_tensor
         )
+        # set score to 0 where input_ids is a padding token
+        score *= input_ids.ne(0)

         return scores.scatter_add_(1, input_ids, score)

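Note: the new score *= input_ids.ne(0) line stops the frequency penalty from being scatter-added onto token id 0, which the gathered input_ids use as padding. A runnable miniature of the same computation (the vocabulary size, penalty and ids are made up):

import torch

penalty = 1.2
scores = torch.ones(1, 6)                 # (batch, vocab)
input_ids = torch.tensor([[5, 3, 0, 0]])  # previously seen ids, 0 = padding

score = torch.gather(scores, 1, input_ids)
# if score < 0 then penalty has to be multiplied to reduce the previous token probability
score = -torch.where(score < 0, score * penalty, score / penalty)
# set score to 0 where input_ids is a padding token
score *= input_ids.ne(0)
scores.scatter_add_(1, input_ids, score)
print(scores)  # column 0 stays at 1.0; columns 3 and 5 are penalized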
@@ -1,7 +1,6 @@
 import torch

-# vllm imports
-from vllm._C import cache_ops, ops
+from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM

 _PARTITION_SIZE = 512

@@ -13,7 +12,18 @@ def reshape_and_cache(
     value_cache: torch.Tensor,
     slots: torch.Tensor,
 ):
-    cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots, "auto", 1.0)
+    if IS_CUDA_SYSTEM:
+        from vllm._C import cache_ops
+
+        cache_ops.reshape_and_cache(
+            key, value, key_cache, value_cache, slots, "auto", 1.0
+        )
+    elif IS_ROCM_SYSTEM:
+        from vllm import cache_ops
+
+        cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots)
+    else:
+        raise ValueError("vllm is not supported on your system")


 def attention(
@@ -55,21 +65,43 @@ def attention(
     # to parallelize.
     use_v1 = max_s <= 8192 and (max_num_partitions == 1 or num_seqs * num_heads > 512)
     if use_v1:
-        ops.paged_attention_v1(
-            out,
-            query,
-            key_cache,
-            value_cache,
-            kv_head_mapping,
-            softmax_scale,
-            block_tables,
-            input_lengths,
-            block_size,
-            max_s,
-            None,
-            "auto",
-            1.0,
-        )
+        if IS_CUDA_SYSTEM:
+            from vllm._C import ops
+
+            ops.paged_attention_v1(
+                out,
+                query,
+                key_cache,
+                value_cache,
+                kv_head_mapping,
+                softmax_scale,
+                block_tables,
+                input_lengths,
+                block_size,
+                max_s,
+                None,
+                "auto",
+                1.0,
+            )
+        elif IS_ROCM_SYSTEM:
+            from vllm import attention_ops
+
+            attention_ops.paged_attention_v1(
+                out,
+                query,
+                key_cache,
+                value_cache,
+                kv_head_mapping,
+                softmax_scale,
+                block_tables,
+                input_lengths,
+                block_size,
+                max_s,
+                None,
+            )
+        else:
+            raise ValueError("vllm is not supported on your system")
     else:
         # Run PagedAttention V2.
         assert _PARTITION_SIZE % block_size == 0
@@ -84,21 +116,46 @@ def attention(
             device=out.device,
         )
         max_logits = torch.empty_like(exp_sums)
-        ops.paged_attention_v2(
-            out,
-            exp_sums,
-            max_logits,
-            tmp_output,
-            query,
-            key_cache,
-            value_cache,
-            kv_head_mapping,
-            softmax_scale,
-            block_tables,
-            input_lengths,
-            block_size,
-            max_s,
-            None,
-            "auto",
-            1.0,
-        )
+        if IS_CUDA_SYSTEM:
+            from vllm._C import ops
+
+            ops.paged_attention_v2(
+                out,
+                exp_sums,
+                max_logits,
+                tmp_output,
+                query,
+                key_cache,
+                value_cache,
+                kv_head_mapping,
+                softmax_scale,
+                block_tables,
+                input_lengths,
+                block_size,
+                max_s,
+                None,
+                "auto",
+                1.0,
+            )
+        elif IS_ROCM_SYSTEM:
+            from vllm import attention_ops
+
+            attention_ops.paged_attention_v2(
+                out,
+                exp_sums,
+                max_logits,
+                tmp_output,
+                query,
+                key_cache,
+                value_cache,
+                kv_head_mapping,
+                softmax_scale,
+                block_tables,
+                input_lengths,
+                block_size,
+                max_s,
+                None,
+            )
+        else:
+            raise ValueError("vllm is not supported on your system")
@@ -1,5 +1,5 @@
 import re
-from typing import List, Optional, Tuple
+from typing import List, Optional, Tuple, Set, Union

 import math
 import torch
@@ -143,12 +143,22 @@ class StopSequenceCriteria:
 class StoppingCriteria:
     def __init__(
         self,
-        eos_token_id: int,
+        eos_token_ids: Optional[Union[Set[int], int]],
         stop_sequence_criterias: List[StopSequenceCriteria],
         max_new_tokens: int = 20,
         ignore_eos_token: bool = False,
     ):
-        self.eos_token_id = eos_token_id
+        if eos_token_ids is None:
+            eos_token_ids = set()
+        elif isinstance(eos_token_ids, int):
+            eos_token_ids = set([eos_token_ids])
+        elif isinstance(eos_token_ids, set):
+            eos_token_ids = eos_token_ids
+        else:
+            raise RuntimeError(
+                f"eos_token_ids is of invalid type {type(eos_token_ids)}, expected int, None or set[int]"
+            )
+        self.eos_token_ids = eos_token_ids
         self.stop_sequence_criterias = stop_sequence_criterias
         self.max_new_tokens = max_new_tokens
         self.current_tokens = 0
@@ -160,7 +170,10 @@ class StoppingCriteria:
         if self.current_tokens >= self.max_new_tokens:
             return True, FinishReason.FINISH_REASON_LENGTH

-        if not self.ignore_eos_token and last_token == self.eos_token_id:
+        if isinstance(last_token, torch.Tensor):
+            last_token = last_token.item()
+
+        if not self.ignore_eos_token and last_token in self.eos_token_ids:
             return True, FinishReason.FINISH_REASON_EOS_TOKEN

         if self.stop_sequence_criterias:
@@ -184,8 +197,10 @@ class StoppingCriteria:
         stop_sequence_criterias = [
             StopSequenceCriteria(sequence) for sequence in pb.stop_sequences
         ]
+        # TODO Hack because eos_token_id cannot be what we want.
+        eos_token_id = getattr(tokenizer, "_eos_token_ids", tokenizer.eos_token_id)
         return StoppingCriteria(
-            tokenizer.eos_token_id,
+            eos_token_id,
             stop_sequence_criterias,
             pb.max_new_tokens,
             pb.ignore_eos_token,
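Note: taken together, from_pb now prefers the _eos_token_ids set stashed on the tokenizer by the FlashLlama change above, and StoppingCriteria accepts an int, a set, or None. A reduced, self-contained sketch of the normalization and the stopping check itself (the ids and the FinishReason naming are stand-ins):

import torch

def normalize_eos(eos_token_ids):
    if eos_token_ids is None:
        return set()
    if isinstance(eos_token_ids, int):
        return {eos_token_ids}
    if isinstance(eos_token_ids, set):
        return eos_token_ids
    raise RuntimeError(f"eos_token_ids is of invalid type {type(eos_token_ids)}")

eos_ids = normalize_eos({2, 32000})   # e.g. tokenizer._eos_token_ids
last_token = torch.tensor(32000)      # generated tokens may arrive as tensors
if isinstance(last_token, torch.Tensor):
    last_token = last_token.item()
print(last_token in eos_ids)          # True, i.e. generation stops on EOS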
@@ -273,7 +288,7 @@ class HeterogeneousNextTokenChooser:
             else None
         )

-        if any([x != 1.0 for x in temperature]):
+        if any(x != 1.0 for x in temperature):
             do_sample = [
                 sample or x != 1.0 for x, sample in zip(temperature, do_sample)
             ]
@@ -281,15 +296,15 @@ class HeterogeneousNextTokenChooser:
                 HeterogeneousTemperatureLogitsWarper(temperature, dtype, device)
             )

-        if any([x != 0 for x in top_k]):
+        if any(x != 0 for x in top_k):
             do_sample = [sample or x != 0 for x, sample in zip(top_k, do_sample)]
             warpers.append(HeterogeneousTopKLogitsWarper(top_k, device))

-        if any([x < 1.0 for x in top_p]):
+        if any(x < 1.0 for x in top_p):
             do_sample = [sample or x < 1.0 for x, sample in zip(top_p, do_sample)]
             warpers.append(HeterogeneousTopPLogitsWarper(top_p, dtype, device))

-        if any([x < 1.0 for x in typical_p]):
+        if any(x < 1.0 for x in typical_p):
             do_sample = [sample or x < 1.0 for x, sample in zip(typical_p, do_sample)]
             warpers.append(HeterogeneousTypicalLogitsWarper(typical_p, dtype, device))

@@ -141,6 +141,12 @@ class Weights:
         return weight

     def get_weights_col_packed_qkv(self, prefix: str, quantize: str):
+        return self.get_weights_col_packed(prefix, quantize, 3)
+
+    def get_weights_col_packed_gate_up(self, prefix: str, quantize: str):
+        return self.get_weights_col_packed(prefix, quantize, 2)
+
+    def get_weights_col_packed(self, prefix: str, quantize: str, blocks: int):
         """
         Highly specific when the underlying tensor is a simple cat of Q,K,V instead of being
         already alternating Q,K,V within the main tensor
@@ -181,8 +187,8 @@ class Weights:
         else:
             slice_ = self._get_slice(f"{prefix}.weight")
             total_size = slice_.get_shape()[0]
-            assert total_size % 3 == 0, "Prepacked qkv is not divisible by 3"
-            single_size = total_size // 3
+            assert total_size % blocks == 0, f"Prepacked is not divisible by {blocks}"
+            single_size = total_size // blocks
             world_size = self.process_group.size()
             rank = self.process_group.rank()

@@ -192,10 +198,11 @@ class Weights:
             block_size = single_size // world_size
             start = rank * block_size
             stop = (rank + 1) * block_size
-            q = slice_[start:stop]
-            k = slice_[start + single_size : stop + single_size]
-            v = slice_[start + 2 * single_size : stop + 2 * single_size]
-            weight = torch.cat([q, k, v], dim=0)
+            tensors = []
+            for i in range(blocks):
+                tensor = slice_[start + i * single_size : stop + i * single_size]
+                tensors.append(tensor)
+            weight = torch.cat(tensors, dim=0)
             weight = weight.to(device=self.device)
             weight = weight.to(dtype=self.dtype)
             return weight
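Note: generalizing from three hard-coded q/k/v slices to a blocks parameter lets the same path serve packed QKV (blocks=3) and packed gate/up (blocks=2): each tensor-parallel rank takes its shard from every block and re-concatenates them. A runnable miniature of the slicing (the world size, block sizes and values below are made up):

import torch

blocks, single_size, world_size, rank = 3, 6, 2, 1  # e.g. packed QKV, 2 ranks, rank 1
slice_ = torch.arange(blocks * single_size)          # stand-in for the packed weight

assert single_size % world_size == 0
block_size = single_size // world_size
start, stop = rank * block_size, (rank + 1) * block_size

tensors = []
for i in range(blocks):
    tensors.append(slice_[start + i * single_size : stop + i * single_size])
weight = torch.cat(tensors, dim=0)
print(weight)  # tensor([ 3,  4,  5,  9, 10, 11, 15, 16, 17]): rank 1's q, k and v shards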