Mirror of https://github.com/huggingface/text-generation-inference.git
Synced 2025-06-15 13:52:06 +00:00

Speculative (#1308)
This commit is contained in: parent a41c1a6bc7, commit a7f52f3812

Cargo.lock (generated): 254 lines changed
The Cargo.lock changes are dependency version bumps, package additions and removals, and the matching registry checksum updates:

- anyhow 1.0.81 → 1.0.82
- async-trait 0.1.79 → 0.1.80
- bumpalo 3.15.4 → 3.16.0
- cc 1.0.90 → 1.0.94
- either 1.10.0 → 1.11.0
- encoding_rs 0.8.33 → 0.8.34
- getrandom 0.2.12 → 0.2.14
- h2 0.3.25 → 0.3.26
- monostate 0.1.11 → 0.1.12
- monostate-impl 0.1.11 → 0.1.12
- multimap 0.8.3 → 0.10.0
- prettyplease 0.2.17 → 0.2.19
- proc-macro2 1.0.79 → 1.0.81
- prost 0.12.3 → 0.12.4
- prost-build 0.12.3 → 0.12.4
- prost-derive 0.12.3 → 0.12.4
- prost-types 0.12.3 → 0.12.4
- quote 1.0.35 → 1.0.36
- rustversion 1.0.14 → 1.0.15
- serde 1.0.197 → 1.0.198
- serde_derive 1.0.197 → 1.0.198
- serde_json 1.0.115 → 1.0.116
- syn 2.0.58 → 2.0.60
- sysinfo 0.30.8 → 0.30.10
- time 0.3.34 → 0.3.36
- time-macros 0.2.17 → 0.2.18
- windows-targets 0.52.4 → 0.52.5
- windows_aarch64_gnullvm 0.52.4 → 0.52.5
- windows_aarch64_msvc 0.52.4 → 0.52.5
- windows_i686_gnu 0.52.4 → 0.52.5
- windows_i686_msvc 0.52.4 → 0.52.5
- windows_x86_64_gnu 0.52.4 → 0.52.5
- windows_x86_64_gnullvm 0.52.4 → 0.52.5
- windows_x86_64_msvc 0.52.4 → 0.52.5

Added packages:

- itertools 0.12.1 (prost-build and prost-derive move to it from itertools 0.11.0)
- windows_i686_gnullvm 0.52.5 (new dependency of windows-targets 0.52.5)

Removed packages:

- home 0.5.9
- which 4.4.2 (no longer a dependency of prost-build)

Other dependency-list updates: prost-build switches from heck 0.4.1 to heck 0.5.0 and drops its `which` dependency, and all dependency-list references to the bumped crates (`syn 2.0.58`, `prost 0.12.3`, `prost-derive 0.12.3`, `itertools 0.11.0`, the windows_* 0.52.4 crates) are updated to the new versions accordingly, including in tonic and the workspace crates.
@@ -67,6 +67,14 @@ Options:
   - bitsandbytes-nf4: Bitsandbytes 4bit. Can be applied on any model, will cut the memory requirement by 4x, but it is known that the model will be much slower to run than the native f16
   - bitsandbytes-fp4: Bitsandbytes 4bit. nf4 should be preferred in most cases but maybe this one has better perplexity performance for you model
 
+```
+## SPECULATE
+```shell
+      --speculate <SPECULATE>
+          The number of input_ids to speculate on If using a medusa model, the heads will be picked up automatically Other wise, it will use n-gram speculation which is relatively free in terms of compute, but the speedup heavily depends on the task
+
+          [env: SPECULATE=]
+
 ```
 ## DTYPE
 ```shell
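A minimal usage sketch for the new option, assuming the standard `text-generation-launcher` entry point of this repository; the model id below is only a placeholder:

```shell
# Hypothetical example: speculate 3 input_ids per decoding step via the CLI flag
text-generation-launcher --model-id myorg/my-model --speculate 3

# Equivalent, via the environment variable listed in the help text above
SPECULATE=3 text-generation-launcher --model-id myorg/my-model
```

With a medusa model the heads are picked up automatically; otherwise the value controls n-gram speculation as described in the help text.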
New file (98 lines):
{
  "details": {
    "best_of_sequences": null,
    "finish_reason": "length",
    "generated_tokens": 10,
    "prefill": [
      { "id": 1, "logprob": null, "text": "<s>" },
      { "id": 338, "logprob": -10.0078125, "text": "is" },
      { "id": 21784, "logprob": -15.515625, "text": "Deep" },
      { "id": 29257, "logprob": -2.8847656, "text": "Learning" },
      { "id": 29973, "logprob": -4.140625, "text": "?" }
    ],
    "seed": 0,
    "tokens": [
      { "id": 13, "logprob": -1.1582031, "special": false, "text": "\n" },
      { "id": 2772, "logprob": -0.23083496, "special": false, "text": "De" },
      { "id": 1022, "logprob": 0.0, "special": false, "text": "ep" },
      { "id": 6509, "logprob": 0.0, "special": false, "text": " learning" },
      { "id": 29892, "logprob": -0.61816406, "special": false, "text": "," },
      { "id": 607, "logprob": -0.7089844, "special": false, "text": " which" },
      { "id": 508, "logprob": -1.7724609, "special": false, "text": " can" },
      { "id": 367, "logprob": 0.0, "special": false, "text": " be" },
      { "id": 5545, "logprob": 0.0, "special": false, "text": " considered" },
      { "id": 408, "logprob": -0.3869629, "special": false, "text": " as" }
    ]
  },
  "generated_text": "What is Deep Learning?\nDeep learning, which can be considered as"
}
New file (414 lines):
[
  {
    "details": {
      "best_of_sequences": null,
      "finish_reason": "length",
      "generated_tokens": 10,
      "prefill": [
        { "id": 1, "logprob": null, "text": "<s>" },
        { "id": 1724, "logprob": -10.734375, "text": "What" },
        { "id": 338, "logprob": -1.5488281, "text": "is" },
        { "id": 21784, "logprob": -9.2890625, "text": "Deep" },
        { "id": 29257, "logprob": -1.2753906, "text": "Learning" },
        { "id": 29973, "logprob": -0.48046875, "text": "?" }
      ],
      "seed": null,
      "tokens": [
        { "id": 13, "logprob": -1.1845703, "special": false, "text": "\n" },
        { "id": 2772, "logprob": -0.5727539, "special": false, "text": "De" },
        { "id": 1022, "logprob": -0.00010967255, "special": false, "text": "ep" },
        { "id": 6509, "logprob": -0.1239624, "special": false, "text": " learning" },
        { "id": 338, "logprob": -0.04510498, "special": false, "text": " is" },
        { "id": 263, "logprob": -0.018295288, "special": false, "text": " a" },
        { "id": 11306, "logprob": -0.45922852, "special": false, "text": " subset" },
        { "id": 310, "logprob": -0.00020992756, "special": false, "text": " of" },
        { "id": 4933, "logprob": -0.0046539307, "special": false, "text": " machine" },
        { "id": 6509, "logprob": -0.00025844574, "special": false, "text": " learning" }
      ]
    },
    "generated_text": "\nDeep learning is a subset of machine learning"
  },
  {
    "details": {
      "best_of_sequences": null,
      "finish_reason": "length",
      "generated_tokens": 10,
      "prefill": [
        { "id": 1, "logprob": null, "text": "<s>" },
        { "id": 1724, "logprob": -10.734375, "text": "What" },
        { "id": 338, "logprob": -1.5488281, "text": "is" },
        { "id": 21784, "logprob": -9.2890625, "text": "Deep" },
        { "id": 29257, "logprob": -1.2724609, "text": "Learning" },
        { "id": 29973, "logprob": -0.47729492, "text": "?" }
      ],
      "seed": null,
      "tokens": [
        { "id": 13, "logprob": -1.1826172, "special": false, "text": "\n" },
        { "id": 2772, "logprob": -0.56689453, "special": false, "text": "De" },
        { "id": 1022, "logprob": -0.000108003616, "special": false, "text": "ep" },
        { "id": 6509, "logprob": -0.1239624, "special": false, "text": " learning" },
        { "id": 338, "logprob": -0.044433594, "special": false, "text": " is" },
        { "id": 263, "logprob": -0.018295288, "special": false, "text": " a" },
        { "id": 11306, "logprob": -0.45922852, "special": false, "text": " subset" },
        { "id": 310, "logprob": -0.0002104044, "special": false, "text": " of" },
        { "id": 4933, "logprob": -0.004711151, "special": false, "text": " machine" },
        { "id": 6509, "logprob": -0.00025892258, "special": false, "text": " learning" }
      ]
    },
    "generated_text": "\nDeep learning is a subset of machine learning"
  },
  {
    "details": {
      "best_of_sequences": null,
      "finish_reason": "length",
      "generated_tokens": 10,
      "prefill": [
        { "id": 1, "logprob": null, "text": "<s>" },
        { "id": 1724, "logprob": -10.734375, "text": "What" },
        { "id": 338, "logprob": -1.5488281, "text": "is" },
        { "id": 21784, "logprob": -9.2890625, "text": "Deep" },
        { "id": 29257, "logprob": -1.2724609, "text": "Learning" },
        { "id": 29973, "logprob": -0.47729492, "text": "?" }
      ],
      "seed": null,
      "tokens": [
        { "id": 13, "logprob": -1.1826172, "special": false, "text": "\n" },
        { "id": 2772, "logprob": -0.56689453, "special": false, "text": "De" },
        { "id": 1022, "logprob": -0.000108003616, "special": false, "text": "ep" },
        { "id": 6509, "logprob": -0.1239624, "special": false, "text": " learning" },
        { "id": 338, "logprob": -0.044433594, "special": false, "text": " is" },
        { "id": 263, "logprob": -0.018295288, "special": false, "text": " a" },
        { "id": 11306, "logprob": -0.45922852, "special": false, "text": " subset" },
        { "id": 310, "logprob": -0.0002104044, "special": false, "text": " of" },
        { "id": 4933, "logprob": -0.004711151, "special": false, "text": " machine" },
        { "id": 6509, "logprob": -0.00025892258, "special": false, "text": " learning" }
      ]
    },
    "generated_text": "\nDeep learning is a subset of machine learning"
  },
  {
    "details": {
      "best_of_sequences": null,
      "finish_reason": "length",
      "generated_tokens": 10,
      "prefill": [
        { "id": 1, "logprob": null, "text": "<s>" },
        { "id": 1724, "logprob": -10.734375, "text": "What" },
        { "id": 338, "logprob": -1.5488281, "text": "is" },
        { "id": 21784, "logprob": -9.2890625, "text": "Deep" },
        { "id": 29257, "logprob": -1.2724609, "text": "Learning" },
        { "id": 29973, "logprob": -0.47729492, "text": "?" }
      ],
      "seed": null,
      "tokens": [
        { "id": 13, "logprob": -1.1826172, "special": false, "text": "\n" },
        { "id": 2772, "logprob": -0.56689453, "special": false, "text": "De" },
        { "id": 1022, "logprob": -0.000108003616, "special": false, "text": "ep" },
        { "id": 6509, "logprob": -0.1239624, "special": false, "text": " learning" },
        { "id": 338, "logprob": -0.044433594, "special": false, "text": " is" },
        { "id": 263, "logprob": -0.018295288, "special": false, "text": " a" },
        { "id": 11306, "logprob": -0.45922852, "special": false, "text": " subset" },
        { "id": 310, "logprob": -0.0002104044, "special": false, "text": " of" },
        { "id": 4933, "logprob": -0.004711151, "special": false, "text": " machine" },
        { "id": 6509, "logprob": -0.00025892258, "special": false, "text": " learning" }
      ]
    },
    "generated_text": "\nDeep learning is a subset of machine learning"
  }
]
New file (103 lines):
{
  "details": {
    "best_of_sequences": null,
    "finish_reason": "length",
    "generated_tokens": 10,
    "prefill": [
      { "id": 1, "logprob": null, "text": "<s>" },
      { "id": 1724, "logprob": -10.734375, "text": "What" },
      { "id": 338, "logprob": -1.5488281, "text": "is" },
      { "id": 21784, "logprob": -9.2890625, "text": "Deep" },
      { "id": 29257, "logprob": -1.2753906, "text": "Learning" },
      { "id": 29973, "logprob": -0.48046875, "text": "?" }
    ],
    "seed": null,
    "tokens": [
      { "id": 13, "logprob": -1.1845703, "special": false, "text": "\n" },
      { "id": 2772, "logprob": -0.5727539, "special": false, "text": "De" },
      { "id": 1022, "logprob": -0.000108122826, "special": false, "text": "ep" },
      { "id": 6509, "logprob": -0.1239624, "special": false, "text": " learning" },
      { "id": 338, "logprob": -0.044433594, "special": false, "text": " is" },
      { "id": 263, "logprob": -0.01852417, "special": false, "text": " a" },
      { "id": 11306, "logprob": -0.45922852, "special": false, "text": " subset" },
      { "id": 310, "logprob": -0.0002104044, "special": false, "text": " of" },
      { "id": 4933, "logprob": -0.004787445, "special": false, "text": " machine" },
      { "id": 6509, "logprob": -0.00026226044, "special": false, "text": " learning" }
    ]
  },
  "generated_text": "\nDeep learning is a subset of machine learning"
}
integration-tests/models/test_flash_medusa.py (new file, 59 lines):
import pytest


@pytest.fixture(scope="module")
def flash_medusa_handle(launcher):
    with launcher("FasterDecoding/medusa-vicuna-7b-v1.3", num_shard=2) as handle:
        yield handle


@pytest.fixture(scope="module")
async def flash_medusa(flash_medusa_handle):
    await flash_medusa_handle.health(300)
    return flash_medusa_handle.client


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_medusa_simple(flash_medusa, response_snapshot):
    response = await flash_medusa.generate(
        "What is Deep Learning?", max_new_tokens=10, decoder_input_details=True
    )

    assert response.details.generated_tokens == 10
    assert response == response_snapshot


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_medusa_all_params(flash_medusa, response_snapshot):
    response = await flash_medusa.generate(
        "What is Deep Learning?",
        max_new_tokens=10,
        repetition_penalty=1.2,
        return_full_text=True,
        stop_sequences=["test"],
        temperature=0.5,
        top_p=0.9,
        top_k=10,
        truncate=5,
        typical_p=0.9,
        watermark=True,
        decoder_input_details=True,
        seed=0,
    )

    assert response.details.generated_tokens == 10
    assert response == response_snapshot


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_medusa_load(flash_medusa, generate_load, response_snapshot):
    responses = await generate_load(flash_medusa, "What is Deep Learning?", max_new_tokens=10, n=4)

    assert len(responses) == 4
    assert all([r.generated_text == responses[0].generated_text for r in responses]), f"{[r.generated_text for r in responses]}"
    assert responses[0].generated_text == '\nDeep learning is a subset of machine learning'

    assert responses == response_snapshot
@@ -21,6 +21,7 @@ async def test_flash_mistral(flash_mistral, response_snapshot):
     )
 
     assert response.details.generated_tokens == 10
+    assert response.generated_text == ": Let n = 10 - 1"
     assert response == response_snapshot
 
 
@@ -55,6 +56,7 @@ async def test_flash_mistral_load(flash_mistral, generate_load, response_snapshot):
     )
 
     assert len(responses) == 4
-    assert all([r.generated_text == responses[0].generated_text for r in responses])
+    assert all([r.generated_text == responses[0].generated_text for r in responses]), f"{[r.generated_text for r in responses]}"
+    assert responses[0].generated_text == ": Let n = 10 - 1"
 
     assert responses == response_snapshot
@@ -157,6 +157,13 @@ struct Args {
     #[clap(long, env, value_enum)]
     quantize: Option<Quantization>,
 
+    /// The number of input_ids to speculate on
+    /// If using a medusa model, the heads will be picked up automatically
+    /// Other wise, it will use n-gram speculation which is relatively free
+    /// in terms of compute, but the speedup heavily depends on the task.
+    #[clap(long, env)]
+    speculate: Option<usize>,
+
     /// The dtype to be forced upon the model. This option cannot be used with `--quantize`.
     #[clap(long, env, value_enum)]
     dtype: Option<Dtype>,

@@ -377,6 +384,7 @@ fn shard_manager(
     model_id: String,
     revision: Option<String>,
     quantize: Option<Quantization>,
+    speculate: Option<usize>,
     dtype: Option<Dtype>,
     max_total_tokens: usize,
     trust_remote_code: bool,

@@ -435,6 +443,11 @@
         shard_args.push(quantize.to_string())
     }
 
+    if let Some(speculate) = speculate {
+        shard_args.push("--speculate".to_string());
+        shard_args.push(speculate.to_string())
+    }
+
     if let Some(dtype) = dtype {
         shard_args.push("--dtype".to_string());
         shard_args.push(dtype.to_string())

@@ -890,6 +903,7 @@ fn spawn_shards(
         let shutdown_sender = shutdown_sender.clone();
         let otlp_endpoint = args.otlp_endpoint.clone();
         let quantize = args.quantize;
+        let speculate = args.speculate;
         let dtype = args.dtype;
         let max_total_tokens = args.max_total_tokens;
         let trust_remote_code = args.trust_remote_code;

@@ -905,6 +919,7 @@ fn spawn_shards(
             model_id,
             revision,
             quantize,
+            speculate,
             dtype,
             max_total_tokens,
             trust_remote_code,
@@ -7,7 +7,9 @@ const seed = 0;
 
 const host = __ENV.HOST || '127.0.0.1:8000';
 const timePerToken = new Trend('time_per_token', true);
-const throughput = new Counter('tokens_per_s');
+const tokens = new Counter('tokens');
+const new_tokens = new Counter('new_tokens');
+const input_tokens = new Counter('input_tokens');
 
 randomSeed(seed);
 // const shareGPT = JSON.parse(open("ShareGPT_V3_unfiltered_cleaned_split.json"))

@@ -19,7 +21,7 @@ export function get_options(reference_latency_ms){
         thresholds: {
             http_req_failed: ['rate==0'],
             time_per_token: [{
-                threshold: `p(50)<${3 * reference_latency_ms}`,
+                threshold: `p(50)<${5 * reference_latency_ms}`,
                 abortOnFail: true,
                 delayAbortEval: '10s'
             }],

@@ -28,7 +30,7 @@ export function get_options(reference_latency_ms){
             load_test: {
                 executor: 'constant-arrival-rate',
                 duration: '60s',
-                preAllocatedVUs: 100,
+                preAllocatedVUs: 10,
                 rate: 10,
                 timeUnit: '1s',
             },

@@ -48,17 +50,22 @@ export function run(host, generate_payload, max_new_tokens) {
         return;
     }
 
+
     check(res, {
         'Post status is 200': (r) => res.status === 200,
     });
-    const n_tokens = max_new_tokens;
-    const timings = res.timings.duration;
+    const duration = res.timings.duration;
 
     if (res.status === 200) {
-        const latency_ms_per_token = timings / n_tokens;
+        const body = res.json();
+        const n_tokens = body.details.tokens.length;
+        const latency_ms_per_token = duration / n_tokens;
         timePerToken.add(latency_ms_per_token);
         const latency_in_s = latency_ms_per_token / 1000;
         const individual_throughput = 1 / latency_in_s;
-        throughput.add(individual_throughput);
+        const _input_tokens = body.details.prefill.length;
+        tokens.add(n_tokens + _input_tokens);
+        input_tokens.add(_input_tokens);
+        new_tokens.add(n_tokens);
     }
 }
@@ -1,13 +1,13 @@
 import { get_options, run } from "./common.js";
 
-const reference_latency_ms = 30;
+const reference_latency_ms = 70;
 const host = __ENV.HOST || '127.0.0.1:8000';
 const max_new_tokens = 50;
 
 
 function generate_payload(gpt){
     const input = gpt["conversations"][0]["value"];
-    return {"inputs": input, "parameters": {"max_new_tokens": max_new_tokens, "temperature" : 0.5}}
+    return {"inputs": input, "parameters": {"max_new_tokens": max_new_tokens, "decoder_input_details": true}}
 }
 
 export const options = get_options(reference_latency_ms);
@ -1,6 +1,6 @@
syntax = "proto3";

-package generate.v1;
+package generate.v2;

service TextGenerationService {
/// Model Info

@ -32,6 +32,7 @@ message InfoResponse {
string dtype = 2;
string device_type = 3;
optional uint32 window_size = 4;
+uint32 speculate = 5;
}

/// Empty request

@ -135,43 +136,27 @@ message GeneratedText {
optional uint64 seed = 4;
}

-message PrefillTokens {
-/// Prefill Token IDs
+message Tokens {
+/// Token IDs
repeated uint32 ids = 1;
-/// Prefill Logprobs
+/// Logprobs
repeated float logprobs = 2;
-/// Prefill tokens
+/// tokens
repeated string texts = 3;
-}
-message TopTokens {
-/// Top Token IDs
-repeated uint32 ids = 1;
-/// Top Logprobs
-repeated float logprobs = 2;
-/// Top Token Texts
-repeated string texts = 3;
-/// If the tokens are special
-repeated bool is_special = 6;
+/// special
+repeated bool is_special = 4;
}

message Generation {
/// Request ID
uint64 request_id = 1;
/// Prefill tokens (optional)
-PrefillTokens prefill_tokens = 2;
-/// Token ID
-uint32 token_id = 3;
-/// Logprob
-float token_logprob = 4;
-/// Text
-string token_text = 5;
-/// Is it a special token
-bool token_is_special = 6;
+Tokens prefill_tokens = 2;
+Tokens tokens = 3;
/// Complete generated text
-optional GeneratedText generated_text = 7;
+optional GeneratedText generated_text = 4;
/// Top tokens
-TopTokens top_tokens = 8;
+repeated Tokens top_tokens = 5;
}

message FilterBatchRequest {

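The v2 protocol above folds `PrefillTokens`, `TopTokens` and the per-step scalar token fields into a single `Tokens` message, so one `Generation` can carry several tokens per decoding step. An illustrative Python mirror of the new shapes (not the generated protobuf code; `generated_text` is simplified to a string here):

```python
# Illustrative Python mirror of the v2 message shapes above (not generated pb code).
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class Tokens:
    ids: List[int] = field(default_factory=list)
    logprobs: List[float] = field(default_factory=list)
    texts: List[str] = field(default_factory=list)
    is_special: List[bool] = field(default_factory=list)


@dataclass
class Generation:
    request_id: int
    prefill_tokens: Optional[Tokens]
    tokens: Tokens                 # one or more tokens per decoding step
    generated_text: Optional[str]  # stands in for the GeneratedText message
    top_tokens: List[Tokens] = field(default_factory=list)


# A speculative step that accepted three tokens for one request:
step = Generation(
    request_id=0,
    prefill_tokens=None,
    tokens=Tokens(ids=[13, 198, 464], logprobs=[-0.1, -0.4, -0.9],
                  texts=[".", "\n", " The"], is_special=[False, False, False]),
    generated_text=None,
)
assert len(step.tokens.ids) == 3
```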
@ -1,8 +1,8 @@
/// Copyright (C) 2024 Habana Labs, Ltd. an Intel Company.

/// Single shard Client
-use crate::pb::generate::v1::text_generation_service_client::TextGenerationServiceClient;
-use crate::pb::generate::v1::*;
+use crate::pb::generate::v2::text_generation_service_client::TextGenerationServiceClient;
+use crate::pb::generate::v2::*;
use crate::Result;
use std::env;
use rand::{distributions::Uniform, Rng};

@ -6,11 +6,11 @@ mod pb;
mod sharded_client;

pub use client::Client;
-pub use pb::generate::v1::HealthResponse;
-pub use pb::generate::v1::InfoResponse as ShardInfo;
-pub use pb::generate::v1::{
+pub use pb::generate::v2::HealthResponse;
+pub use pb::generate::v2::InfoResponse as ShardInfo;
+pub use pb::generate::v2::{
Batch, CachedBatch, FinishReason, GeneratedText, Generation, NextTokenChooserParameters,
-PrefillTokens, Request, StoppingCriteriaParameters,
+Request, StoppingCriteriaParameters, Tokens,
};
pub use sharded_client::ShardedClient;
use thiserror::Error;

@ -11,7 +11,7 @@ use std::sync::{
Arc,
};
use text_generation_client::{
-Batch, CachedBatch, ClientError, GeneratedText, Generation, PrefillTokens, ShardedClient,
+Batch, CachedBatch, ClientError, GeneratedText, Generation, ShardedClient, Tokens,
};
use thiserror::Error;
use tokio::sync::mpsc::error::SendError;

@ -54,6 +54,7 @@ impl Infer {
max_input_length: u32,
max_total_tokens: u32,
window_size: Option<u32>,
+speculate: u32,
generation_health: Arc<AtomicBool>,
) -> Self {
// Infer shared state

@ -62,7 +63,8 @@ impl Infer {
max_input_length,
max_total_tokens,
16,
-window_size
+window_size,
+speculate
);
let shared = Arc::new(Shared {
batching_task: Notify::new(),

@ -533,50 +535,63 @@ fn send_responses(
}

// Create last Token
-let token = Token {
-id: generation.token_id,
-text: generation.token_text,
-logprob: generation.token_logprob,
-special: generation.token_is_special,
-};
-// generation.top_tokens
-let mut top_tokens = Vec::new();
-if let Some(top_tokens_) = generation.top_tokens {
-top_tokens.extend(
+let tokens_ = generation.tokens.expect("Non empty tokens in generation");
+let n = tokens_.ids.len();
+metrics::histogram!("tgi_request_skipped_tokens", (n - 1) as f64);
+let mut iterator = tokens_
+.ids
+.into_iter()
+.zip(tokens_.logprobs.into_iter())
+.zip(tokens_.texts.into_iter())
+.zip(tokens_.is_special.into_iter())
+.enumerate()
+.peekable();
+while let Some((i, (((id, logprob), text), special))) = iterator.next() {
+let token = Token {
+id,
+text,
+logprob,
+special,
+};
+let top_tokens = if let Some(top_tokens_) = generation.top_tokens.get(i) {
top_tokens_
.ids
-.into_iter()
-.zip(top_tokens_.logprobs.into_iter())
-.zip(top_tokens_.texts.into_iter())
-.zip(top_tokens_.is_special.into_iter())
-.map(|(((id, logprob), text), special)| Token {
+.iter()
+.zip(top_tokens_.logprobs.iter())
+.zip(top_tokens_.texts.iter())
+.zip(top_tokens_.is_special.iter())
+.map(|(((&id, &logprob), text), &special)| Token {
id,
-text,
+text: text.to_string(),
logprob,
special,
-}),
-)
+})
+.collect()
+} else {
+vec![]
+};
+match (&generation.generated_text, iterator.peek()) {
+(Some(generated_text), None) => {
+// Generation has ended
+stopped = true;
+// Send message
+entry.response_tx.send(Ok(InferStreamResponse::End {
+token,
+top_tokens,
+generated_text: generated_text.clone(),
+queued: entry.queue_time,
+start: entry.batch_time.unwrap(),
+}))?;
+}
+_ => {
+// Send message
+entry
+.response_tx
+.send(Ok(InferStreamResponse::Intermediate { token, top_tokens }))?;
+}
+}
}

-if let Some(generated_text) = generation.generated_text {
-// Generation has ended
-stopped = true;
-// Send message
-entry.response_tx.send(Ok(InferStreamResponse::End {
-token,
-top_tokens,
-generated_text,
-queued: entry.queue_time,
-start: entry.batch_time.unwrap(),
-}))?;
-} else {
-// Send message
-entry
-.response_tx
-.send(Ok(InferStreamResponse::Intermediate { token, top_tokens }))?;
-}
Ok(stopped)
}

@ -601,7 +616,7 @@ fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>) {
#[derive(Debug)]
pub(crate) enum InferStreamResponse {
// Optional first message
-Prefill(PrefillTokens),
+Prefill(Tokens),
// Intermediate messages
Intermediate {
token: Token,

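The `send_responses` rewrite above streams every token carried by `generation.tokens` and attaches the final `GeneratedText` only to the last one. A small Python sketch of that control flow (the real implementation is the Rust above; the names here are illustrative):

```python
# Python sketch of the control flow in the send_responses hunk above: stream every
# accepted token, and only the final token of a finished generation carries the
# end-of-generation payload. Names and tuple shapes are illustrative assumptions.
from typing import Optional


def split_stream_events(ids, logprobs, texts, specials, generated_text: Optional[str]):
    events = []
    tokens = list(zip(ids, logprobs, texts, specials))
    for i, token in enumerate(tokens):
        is_last = i == len(tokens) - 1
        if is_last and generated_text is not None:
            events.append(("End", token, generated_text))  # generation finished
        else:
            events.append(("Intermediate", token, None))    # keep streaming
    return events


events = split_stream_events([5, 6], [-0.2, -0.3], ["a", "b"], [False, False], "ab")
assert [kind for kind, _, _ in events] == ["Intermediate", "End"]
```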
@ -44,7 +44,8 @@ impl Queue {
max_input_length: u32,
max_total_tokens: u32,
block_size: u32,
-window_size: Option<u32>
+window_size: Option<u32>,
+speculate: u32,
) -> Self {
// Create channel
let (queue_sender, queue_receiver) = mpsc::unbounded_channel();

@ -56,6 +57,7 @@ impl Queue {
max_total_tokens,
block_size,
window_size,
+speculate,
queue_receiver,
));

@ -106,6 +108,7 @@ async fn queue_task(
max_total_tokens: u32,
block_size: u32,
window_size: Option<u32>,
+speculate: u32,
mut receiver: mpsc::UnboundedReceiver<QueueCommand>,
) {
let mut state = State::new(

@ -113,7 +116,8 @@ async fn queue_task(
max_input_length,
max_total_tokens,
block_size,
-window_size
+window_size,
+speculate
);

while let Some(cmd) = receiver.recv().await {

@ -256,6 +260,9 @@ struct State {

/// Sliding window
window_size: Option<u32>,

+/// Speculation amount
+speculate: u32,
}

impl State {

@ -265,6 +272,7 @@ impl State {
max_total_tokens: u32,
block_size: u32,
window_size: Option<u32>,
+speculate: u32,
) -> Self {
let default_threshold: u64 = 120;
let threshold: u64 = match env::var("QUEUE_THRESHOLD_MS") {

@ -281,6 +289,7 @@ impl State {
max_total_tokens,
block_size,
window_size,
+speculate,
}
}

@ -365,7 +374,7 @@ impl State {
}

if prefill_tokens > prefill_token_budget
-|| (prefill_tokens + decode_tokens) > token_budget
+|| (prefill_tokens + decode_tokens + self.speculate) > token_budget
{
// Entry is over budget
// Add it back to the front

@ -457,13 +466,13 @@ mod tests {

fn default_queue() -> Queue {
Queue::new(
-true, 1, 2, 1, None
+true, 1, 2, 1, None, 0
)
}

fn default_state() -> State {
State::new(
-true, 1, 2, 1, None
+true, 1, 2, 1, None, 0
)
}

@ -667,6 +676,25 @@ mod tests {
assert_eq!(batch.size, 2);
}

+#[tokio::test]
+async fn test_queue_next_batch_token_speculate() {
+let queue = Queue::new(true, 1, 2, 1, None, 2);
+let (entry1, _guard1) = default_entry();
+let (entry2, _guard2) = default_entry();
+queue.append(entry1);
+queue.append(entry2);
+
+// Budget of 1 is not enough
+assert!(queue.next_batch(None, 1, 1).await.is_none());
+
+let (entries, batch, _) = queue.next_batch(None, 6, 6).await.unwrap();
+assert_eq!(entries.len(), 2);
+assert!(entries.contains_key(&0));
+assert!(entries.contains_key(&1));
+assert_eq!(batch.id, 0);
+assert_eq!(batch.size, 2);
+}

#[tokio::test]
async fn test_queue_next_batch_dropped_receiver() {
let queue = default_queue();

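With the queue change above, an entry is only admitted to a batch if its prefill tokens, decode tokens and `speculate` extra slots all fit the token budget, which is what the new `test_queue_next_batch_token_speculate` exercises. A sketch of the budget arithmetic (the 1-prefill/1-decode figures below are assumptions for illustration, mirroring the test loosely):

```python
# Budget arithmetic sketch for the queue change above: each admitted entry must
# fit its prefill tokens, its decode tokens, and `speculate` extra slots reserved
# for speculative tokens.
def fits_budget(prefill_tokens: int, decode_tokens: int, speculate: int,
                prefill_token_budget: int, token_budget: int) -> bool:
    if prefill_tokens > prefill_token_budget:
        return False
    return prefill_tokens + decode_tokens + speculate <= token_budget


# With speculate=2, an entry needing 1 prefill + 1 decode token no longer fits a
# budget of 1, but does fit a budget of 6 (1 + 1 + 2 = 4 <= 6).
assert not fits_budget(1, 1, 2, prefill_token_budget=1, token_budget=1)
assert fits_budget(1, 1, 2, prefill_token_budget=6, token_budget=6)
```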
@ -600,6 +600,7 @@ pub async fn run(
max_input_length as u32,
max_total_tokens as u32,
shard_info.window_size,
+shard_info.speculate,
generation_health,
);

@ -1,22 +1,25 @@
-build-vllm-cuda: REPOSITORY=https://github.com/vllm-project/vllm.git
+vllm-cuda:
-build-vllm-cuda: VLLM_COMMIT=f8a1e39fae05ca610be8d5a78be9d40f5274e5fc
-build-vllm-cuda: BRANCH=main
-build-vllm-cuda: build-vllm

-build-vllm-rocm: REPOSITORY=https://github.com/fxmarty/vllm-public.git
-build-vllm-rocm: VLLM_COMMIT=ad9b7c4095ef54419a0533d254f2ad84bd2dfcae
-build-vllm-rocm: BRANCH=rotary-no-positions-split-cos-sin
-build-vllm-rocm: build-vllm

-vllm:
# Clone vllm
pip install -U ninja packaging --no-cache-dir
-git clone --single-branch --branch $(BRANCH) $(REPOSITORY) vllm
+git clone https://github.com/vllm-project/vllm.git vllm

-build-vllm: vllm
+build-vllm-cuda: vllm-cuda
-cd vllm && git fetch && git checkout $(VLLM_COMMIT)
+cd vllm && git fetch && git checkout f8a1e39fae05ca610be8d5a78be9d40f5274e5fc
cd vllm && python setup.py build

-install-vllm: build-vllm
+install-vllm-cuda: build-vllm-cuda
+pip uninstall vllm -y || true
+cd vllm && python setup.py install
+
+vllm-rocm:
+# Clone vllm
+pip install -U ninja packaging --no-cache-dir
+git clone https://github.com/fxmarty/vllm-public.git vllm
+
+build-vllm-rocm: vllm-rocm
+cd vllm && git fetch && git checkout ad9b7c4095ef54419a0533d254f2ad84bd2dfcae
+cd vllm && python setup.py build
+
+install-vllm-rocm: build-vllm-rocm
pip uninstall vllm -y || true
cd vllm && python setup.py install

@ -135,8 +135,8 @@ def test_causal_lm_generate_token(default_bloom, default_bloom_batch):
)
assert all([generation.generated_text is None for generation in generations])
assert all([len(generation.prefill_tokens) == 1 for generation in generations])
-assert all([generation.token_id.item() == 10264 for generation in generations])
-assert all([generation.token_text == "Test" for generation in generations])
+assert all([token_id.item() == 10264 for generation in generations for token_id in generation.tokens.token_ids])
+assert all([token_text == "Test" for generation in generations for token_text in generation.tokens.texts])
assert generations[0].request_id == 0

@ -141,8 +141,8 @@ def test_causal_lm_generate_token(default_causal_lm, default_causal_lm_batch):
)
assert all([generation.generated_text is None for generation in generations])
assert all([len(generation.prefill_tokens) == 1 for generation in generations])
-assert all([generation.token_id.item() == 13 for generation in generations])
-assert all([generation.token_text == "." for generation in generations])
+assert all([token_id.item() == 13 for generation in generations for token_id in generation.tokens.token_ids])
+assert all([token_text == "." for generation in generations for token_text in generation.tokens.texts])
assert generations[0].request_id == 0

@ -155,8 +155,8 @@ def test_seq2seq_lm_generate_token(default_seq2seq_lm, default_seq2seq_lm_batch)
)
assert all([generation.generated_text is None for generation in generations])
assert all([len(generation.prefill_tokens) == 1 for generation in generations])
-assert all([generation.token_id.item() == 259 for generation in generations])
-assert all([generation.token_text == " " for generation in generations])
+assert all([token_id.item() == 259 for generation in generations for token_id in generation.tokens.token_ids])
+assert all([token_text == " " for generation in generations for token_text in generation.tokens.texts])
assert generations[0].request_id == 0

@ -10,6 +10,7 @@ from pathlib import Path
from loguru import logger
from typing import Optional
from enum import Enum
+from huggingface_hub import hf_hub_download


app = typer.Typer()

@ -31,6 +32,7 @@ def serve(
revision: Optional[str] = None,
sharded: bool = False,
quantize: Optional[Quantization] = None,
+speculate: Optional[int] = None,
dtype: Optional[Dtype] = None,
trust_remote_code: bool = False,
uds_path: Path = "/tmp/text-generation-server",

@ -39,9 +41,15 @@ def serve(
otlp_endpoint: Optional[str] = None,
):
if sharded:
-assert os.getenv("WORLD_SIZE", None) is not None, "WORLD_SIZE must be set when sharded is True"
-assert os.getenv("MASTER_ADDR", None) is not None, "MASTER_ADDR must be set when sharded is True"
-assert os.getenv("MASTER_PORT", None) is not None, "MASTER_PORT must be set when sharded is True"
+assert (
+os.getenv("WORLD_SIZE", None) is not None
+), "WORLD_SIZE must be set when sharded is True"
+assert (
+os.getenv("MASTER_ADDR", None) is not None
+), "MASTER_ADDR must be set when sharded is True"
+assert (
+os.getenv("MASTER_PORT", None) is not None
+), "MASTER_PORT must be set when sharded is True"

# Remove default handler
logger.remove()

@ -75,7 +83,11 @@ def serve(
logger.info("CLI SHARDED = {}".format(num_shard))
import subprocess

-cmd = f"deepspeed --num_nodes 1 --num_gpus {num_shard} --no_local_rank {tgi_file} --model_id {model_id} --revision {revision} --sharded {sharded} --dtype {dtype} --uds_path {uds_path}"
+cmd = f"deepspeed --num_nodes 1 --num_gpus {num_shard} --no_local_rank {tgi_file}"
+cmd += f" --model_id {model_id} --revision {revision} --sharded {sharded}"
+cmd += f" --dtype {dtype} --trust_remote_code {trust_remote_code} --uds_path {uds_path}"
+if speculate is not None:
+cmd += f"--speculate {speculate}"
logger.info("CLI server start deepspeed ={} ".format(cmd))
sys.stdout.flush()
sys.stderr.flush()

@ -119,7 +131,9 @@ def serve(
logger.error(f"{cmd} exited with status = {proc.returncode}")
return proc.returncode
else:
-server.serve(model_id, revision, dtype, uds_path, sharded)
+server.serve(
+model_id, revision, sharded, speculate, dtype, trust_remote_code, uds_path
+)


@app.command()

@ -153,7 +167,7 @@ def download_weights(
logger.info("Files are already present on the host. " "Skipping download.")
return
# Local files not found
-except (utils.LocalEntryNotFoundError, FileNotFoundError):
+except (utils.LocalEntryNotFoundError, FileNotFoundError, utils.EntryNotFoundError):
pass

is_local_model = (Path(model_id).exists() and Path(model_id).is_dir()) or os.getenv(

@ -161,6 +175,42 @@ def download_weights(
) is not None

if not is_local_model:
+try:
+adapter_config_filename = hf_hub_download(
+model_id, revision=revision, filename="adapter_config.json"
+)
+utils.download_and_unload_peft(
+model_id, revision, trust_remote_code=trust_remote_code
+)
+is_local_model = True
+utils.weight_files(model_id, revision, extension)
+return
+except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
+pass
+
+try:
+import json
+medusa_head = hf_hub_download(model_id, revision=revision, filename="medusa_lm_head.pt")
+if auto_convert:
+medusa_sf = Path(medusa_head[:-len(".pt")] + ".safetensors")
+if not medusa_sf.exists():
+utils.convert_files([Path(medusa_head)], [medusa_sf], [])
+medusa_config = hf_hub_download(model_id, revision=revision, filename="config.json")
+with open(medusa_config, "r") as f:
+config = json.load(f)
+
+model_id = config["base_model_name_or_path"]
+revision = "main"
+try:
+utils.weight_files(model_id, revision, extension)
+logger.info(f"Files for parent {model_id} are already present on the host. " "Skipping download.")
+return
+# Local files not found
+except (utils.LocalEntryNotFoundError, FileNotFoundError, utils.EntryNotFoundError):
+pass
+except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
+pass

# Try to download weights from the hub
try:
filenames = utils.weight_hub_files(model_id, revision, extension)

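The `download_weights` hunk above adds Medusa detection: if the repo ships a `medusa_lm_head.pt`, the head is converted to safetensors and weights are then resolved for the base model named in `config.json`. A hedged sketch of that resolution step, where the download/convert helpers are stand-ins rather than the actual text_generation_server utilities:

```python
# Hedged sketch of the Medusa resolution step added above: convert the Medusa head
# to safetensors if needed, then return the parent model id whose weights should
# actually be fetched. `download` and `convert_to_safetensors` are stand-in callables.
import json
from pathlib import Path


def resolve_medusa(model_id: str, revision: str, download, convert_to_safetensors):
    head_pt = Path(download(model_id, revision, "medusa_lm_head.pt"))
    head_sf = head_pt.with_suffix(".safetensors")
    if not head_sf.exists():
        convert_to_safetensors(head_pt, head_sf)
    config = json.loads(Path(download(model_id, revision, "config.json")).read_text())
    # Weights are then fetched for the base model the Medusa heads were trained on.
    return config["base_model_name_or_path"], "main"
```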
@ -1,10 +1,14 @@
import torch

from loguru import logger
+from transformers.configuration_utils import PretrainedConfig
from transformers.models.auto import modeling_auto
-from transformers import AutoConfig
from typing import Optional

+# Needed to properly setup habana_frameworks
+import text_generation_server.habana_quantization_env as hq_env
+
+from text_generation_server.utils.speculate import get_speculate, set_speculate
from text_generation_server.models.model import Model
from text_generation_server.models.causal_lm import CausalLM
from text_generation_server.models.bloom import BLOOM

@ -18,10 +22,46 @@ torch.set_grad_enabled(False)
def get_model(
model_id: str,
revision: Optional[str],
-dtype: Optional[torch.dtype] = None,
+speculate: Optional[int],
+dtype: Optional[torch.dtype],
+trust_remote_code: bool,
) -> Model:
-config = AutoConfig.from_pretrained(model_id, revision=revision)
-model_type = config.model_type
+if speculate is not None:
+set_speculate(speculate)
+else:
+set_speculate(0)
+
+config_dict, _ = PretrainedConfig.get_config_dict(
+model_id, revision=revision, trust_remote_code=trust_remote_code
+)
+
+use_medusa = None
+if "medusa_num_heads" in config_dict:
+use_medusa = model_id
+medusa_config = config_dict
+model_id = config_dict["base_model_name_or_path"]
+revision = "main"
+speculate_medusa = config_dict["medusa_num_heads"]
+if speculate is not None:
+if speculate > speculate_medusa:
+raise RuntimeError("Speculate is set to `{speculate}` but this medusa models only has `{speculate_medusa}` heads, please make them match")
+else:
+set_speculate(speculate)
+else:
+set_speculate(speculate_medusa)
+
+config_dict, _ = PretrainedConfig.get_config_dict(
+model_id, revision=revision, trust_remote_code=trust_remote_code
+)
+method = "medusa"
+else:
+method = "n-gram"
+
+speculate = get_speculate()
+if speculate > 0:
+logger.info(f"Using speculation {method} with {speculate} input ids.")
+
+model_type = config_dict["model_type"]
+
if model_type == "gpt_bigcode":
return SantaCoder(model_id, revision, dtype)

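`get_model` above now chooses between an explicit `--speculate` value, the Medusa head count taken from the model config, and n-gram speculation otherwise. A compact, illustrative sketch of that decision (the config keys are the ones read in the hunk; the function itself is not part of the codebase):

```python
# Compact sketch of the speculation selection added to get_model above.
from typing import Optional, Tuple


def choose_speculation(config_dict: dict, speculate: Optional[int]) -> Tuple[str, int]:
    if "medusa_num_heads" in config_dict:
        heads = config_dict["medusa_num_heads"]
        if speculate is not None and speculate > heads:
            raise RuntimeError(
                f"Speculate is set to {speculate} but this Medusa model only has {heads} heads"
            )
        # Fall back to the number of trained Medusa heads when no value is given.
        return "medusa", speculate if speculate is not None else heads
    # Without Medusa heads, n-gram speculation is used with the requested amount.
    return "n-gram", speculate if speculate is not None else 0


assert choose_speculation({"medusa_num_heads": 4}, None) == ("medusa", 4)
assert choose_speculation({}, 2) == ("n-gram", 2)
```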
@ -35,10 +35,9 @@ from text_generation_server.utils.tokens import batch_top_tokens
from text_generation_server.models import Model
from text_generation_server.models.types import (
Batch,
-PrefillTokens,
+Tokens,
Generation,
GeneratedText,
-TopTokens,
)
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import (

@ -48,6 +47,7 @@ from text_generation_server.utils import (
is_tokenizer_transparent,
)
from text_generation_server.utils.debug import dbg_trace
+from text_generation_server.utils.speculate import get_speculate

tracer = trace.get_tracer(__name__)

@ -647,6 +647,8 @@ class CausalLM(Model):
kwargs["attn_softmax_bf16"] = True
kwargs["trim_logits"] = True

+self.speculate = get_speculate()
+
super(CausalLM, self).__init__(
model=model,
tokenizer=tokenizer,

@ -842,12 +844,12 @@ class CausalLM(Model):
# Select next token
input_length = batch.input_length
if logits.shape[-2] > 1:
-next_token_ids, next_token_logprobs, logprobs = batch.next_token_chooser(
-batch.input_ids, logits[:, input_length - 1: input_length, :].squeeze(-2)
+next_token_ids, next_token_logprobs, logprobs, _, _ = batch.next_token_chooser(
+batch.input_ids, logits[:, input_length - 1: input_length, :].squeeze(-2), self.speculate
)
else:
-next_token_ids, next_token_logprobs, logprobs = batch.next_token_chooser(
-batch.input_ids, logits.squeeze(-2)
+next_token_ids, next_token_logprobs, logprobs, _, _ = batch.next_token_chooser(
+batch.input_ids, logits.squeeze(-2), self.speculate
)
batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
batch.top_n_tokens,

@ -1017,7 +1019,9 @@ class CausalLM(Model):
clean_up_tokenization_spaces=False,
skip_special_tokens=False,
)
-prefill_tokens = PrefillTokens(prefill_token_ids, prefill_logprobs, prefill_texts)
+prefill_tokens = Tokens(
+prefill_token_ids, prefill_logprobs, prefill_texts, is_special=[]
+)
else:
prefill_tokens = None

@ -1027,8 +1031,10 @@ class CausalLM(Model):
clean_up_tokenization_spaces=False,
skip_special_tokens=False,
)
-special_toptokens = [token_id in self.all_special_ids for token_id in top_token_ids]
-top_tokens = TopTokens(
+special_toptokens = [
+token_id in self.all_special_ids for token_id in top_token_ids
+]
+top_tokens = Tokens(
top_token_ids,
top_token_logprobs,
toptoken_texts,

@ -1040,10 +1046,12 @@ class CausalLM(Model):
generation = Generation(
request.id,
prefill_tokens,
-next_token_id,
-next_token_logprob,
-next_token_text,
-next_token_id in self.all_special_ids,
+Tokens(
+[next_token_id],
+[next_token_logprob],
+[next_token_text],
+[next_token_id in self.all_special_ids],
+),
generated_text,
top_tokens,
)

@ -11,13 +11,13 @@ from opentelemetry import trace
from transformers import PreTrainedTokenizerBase
from typing import Optional, Tuple, List, Type, Union, Dict

from text_generation_server.models import Model
+from text_generation_server.utils.speculate import get_speculate
from text_generation_server.models.types import (
Batch,
-PrefillTokens,
+Tokens,
Generation,
GeneratedText,
-TopTokens,
)
from text_generation_server.models.cache_manager import (
get_cache_manager,

@ -41,6 +41,7 @@ class FlashCausalLMBatch(Batch):
# Decoder values
input_ids: torch.Tensor
position_ids: torch.Tensor
+speculative_ids: torch.Tensor

# Flash Attention values

@ -120,6 +121,7 @@ class FlashCausalLMBatch(Batch):
)["input_ids"]

position_ids = []
+speculative_ids = []
cu_seqlen_prefill = [0]
needed_blocks_slots = []
start_slots = []

@ -163,6 +165,8 @@ class FlashCausalLMBatch(Batch):
input_length = len(tokenized_input)
input_lengths.append(input_length)

prefix_offsets.append(input_length - 5)
read_offsets.append(input_length)

@ -186,7 +190,8 @@ class FlashCausalLMBatch(Batch):

# Paged attention
# Remove one as the first token des not have a past
-total_tokens = input_length + max_new_tokens - 1
+speculative_length = get_speculate()
+total_tokens = input_length + max_new_tokens - 1 + speculative_length
needed_blocks = math.ceil(total_tokens / BLOCK_SIZE)
blocks += needed_blocks
needed_blocks_slots.append((needed_blocks, total_tokens))

@ -224,7 +229,7 @@ class FlashCausalLMBatch(Batch):
cumulative_max_length += total_tokens
max_seqlen = max(max_seqlen, input_length)
max_blocks = max(max_blocks, needed_blocks)
-max_length = max(max_length, input_length + max_new_tokens)
+max_length = max(max_length, input_length + max_new_tokens + speculative_length)

next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
next_token_chooser_parameters, dtype, device

@ -255,7 +260,6 @@ class FlashCausalLMBatch(Batch):
cu_seqlen_prefill = torch.tensor(
cu_seqlen_prefill, device=device, dtype=torch.int32
)

position_ids = position_ids.to(device)
slot_indices = slot_indices.to(device)
input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)

@ -309,6 +313,7 @@ class FlashCausalLMBatch(Batch):
top_n_tokens_tensor=top_n_tokens_tensor,
blocks=blocks,
max_blocks=max_blocks,
+speculative_ids=None,
)

@tracer.start_as_current_span("filter")

@ -419,6 +424,7 @@ class FlashCausalLMBatch(Batch):
slots = self.slots[slot_filtering_indices]
next_token_chooser = self.next_token_chooser.filter(indices)
top_n_tokens_tensor = self.top_n_tokens_tensor[indices]
+speculative_ids = self.speculative_ids[indices] if self.speculative_ids is not None else None

start_slots = torch.tensor(start_slots, dtype=torch.int64)

@ -454,6 +460,7 @@ class FlashCausalLMBatch(Batch):
top_n_tokens_tensor=top_n_tokens_tensor,
blocks=blocks,
max_blocks=max_blocks,
+speculative_ids=speculative_ids,
)

@classmethod

@ -473,6 +480,7 @@ class FlashCausalLMBatch(Batch):
total_batch_size += len(b)
total_slots += len(b.slots)
blocks += b.blocks
+speculative_length = b.speculative_ids.shape[1] if b.speculative_ids is not None else 0
max_blocks = max(max_blocks, b.max_blocks)
max_seqlen = max(max_seqlen, b.max_seqlen)
max_length = max(

@ -480,6 +488,7 @@ class FlashCausalLMBatch(Batch):
max(
input_length
+ stopping_criteria.max_new_tokens
++ speculative_length
- stopping_criteria.current_tokens
for input_length, stopping_criteria in zip(
b.input_lengths, b.stopping_criterias

@ -577,6 +586,8 @@ class FlashCausalLMBatch(Batch):
device=batches[0].next_token_chooser.device,
)

+speculative_ids = torch.cat([b.speculative_ids for b in batches], dim=0) if batches[0].speculative_ids is not None else None
+
# Needed to avoid dropping blocks when the batches will go out of scope
for b in batches:
b.block_tables = None

@ -611,6 +622,7 @@ class FlashCausalLMBatch(Batch):
top_n_tokens_tensor=top_n_tokens_tensor,
blocks=blocks,
max_blocks=max_blocks,
+speculative_ids=speculative_ids
)

def __del__(self):

@ -714,16 +726,55 @@ class FlashCausalLM(Model):

def forward(self, batch: FlashCausalLMBatch) -> Tuple[torch.Tensor, torch.Tensor]:
# Model Forward
+if batch.speculative_ids is not None:
+input_ids=batch.input_ids
+position_ids=batch.position_ids
+cu_seqlen_prefill=batch.cu_seqlen_prefill
+kv_cache=get_cache_manager().kv_cache
+block_tables=batch.block_tables_tensor
+slots=batch.slots[batch.slot_indices]
+input_lengths=batch.input_lengths_tensor
+max_s=batch.max_seqlen
+lm_head_indices=batch.prefill_head_indices
+
+speculative_ids = batch.speculative_ids
+
+B, speculative_length = speculative_ids.shape
+new_length = speculative_length + 1
+new_input_ids = torch.cat([input_ids.unsqueeze(-1), speculative_ids], dim=1).reshape(-1)
+arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0)
+arange_int = arange.to(dtype=torch.int32)
+new_position_ids = (position_ids.unsqueeze(-1).expand(B, new_length) + arange).view(-1)
+slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
+input_lengths = (input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
+
+# Add Copy the block tables for all members
+block_tables = block_tables.unsqueeze(1).expand(B, new_length, -1).reshape(B* new_length, -1).contiguous()
+max_s = max_s + speculative_length
+
+input_ids = new_input_ids
+position_ids = new_position_ids
+else:
+input_ids=batch.input_ids
+position_ids=batch.position_ids
+cu_seqlen_prefill=batch.cu_seqlen_prefill
+kv_cache=get_cache_manager().kv_cache
+block_tables=batch.block_tables_tensor
+slots=batch.slots[batch.slot_indices]
+input_lengths=batch.input_lengths_tensor
+max_s=batch.max_seqlen
+lm_head_indices=batch.prefill_head_indices
+
return self.model.forward(
-input_ids=batch.input_ids,
-position_ids=batch.position_ids,
-cu_seqlen_prefill=batch.cu_seqlen_prefill,
-kv_cache=get_cache_manager().kv_cache,
-block_tables=batch.block_tables_tensor,
-slots=batch.slots[batch.slot_indices],
-input_lengths=batch.input_lengths_tensor,
-max_s=batch.max_seqlen,
-lm_head_indices=batch.prefill_head_indices,
+input_ids=input_ids,
+position_ids=position_ids,
+cu_seqlen_prefill=cu_seqlen_prefill,
+kv_cache=kv_cache,
+block_tables=block_tables,
+slots=slots,
+input_lengths=input_lengths,
+max_s=max_s,
+lm_head_indices=lm_head_indices,
)

@tracer.start_as_current_span("generate_token")

@ -752,21 +803,32 @@ class FlashCausalLM(Model):
del batch
raise e

+if isinstance(out, tuple):
+out, speculative_logits = out
+else:
+speculative_logits = None
+
if prefill:
next_token_logits = (
out[batch.prefill_next_token_indices] if prefill_logprobs else out
)
+if speculative_logits is not None:
+speculative_logits = (
+speculative_logits[batch.prefill_next_token_indices] if prefill_logprobs else speculative_logits
+)
else:
next_token_logits = out

-next_input_ids, next_token_logprobs, logprobs = batch.next_token_chooser(
-batch.all_input_ids_tensor[:, : batch.max_seqlen], next_token_logits
+next_input_ids, next_token_logprobs, logprobs, accepted_ids, speculative_ids = batch.next_token_chooser(
+batch.all_input_ids_tensor[:, : batch.max_seqlen], next_token_logits, get_speculate(), batch.speculative_ids, speculative_logits
)

batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs
)

+speculative_length = 0 if speculative_ids is None else speculative_ids.shape[1]
if prefill:
if len(batch) > 1 and prefill_logprobs:
# We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs

@ -792,6 +854,7 @@ class FlashCausalLM(Model):
iterator = zip(
batch.input_lengths,
batch.all_input_ids,
+accepted_ids
)

# We do two for loops as the first one can run completely asynchronously from the GPU while for the second

@ -799,9 +862,11 @@ class FlashCausalLM(Model):
# It is faster if we delay this sync for the maximum amount of time

# For each member of the batch
+index = 0
for i, (
input_length,
all_input_ids,
+n_accepted_ids
) in enumerate(iterator):
# Indexing metadata
start_index = cumulative_length

@ -830,15 +895,18 @@ class FlashCausalLM(Model):
start_index + 1 : start_index + out_length
]

-batch.all_input_ids_tensor[i, input_length] = next_input_ids[i]
+for j in range(n_accepted_ids):
+batch.all_input_ids_tensor[i, input_length + j] = next_input_ids[index]
+index += 1

cumulative_length += input_length

-# Set values in batch
-batch.input_ids = next_input_ids
-batch.position_ids = next_position_ids + 1
-batch.input_lengths_tensor += 1
-batch.slot_indices += 1
+batch.input_ids = next_input_ids[accepted_ids.cumsum(dim=-1) - 1]
+batch.speculative_ids = speculative_ids
+batch.position_ids = next_position_ids + accepted_ids
+batch.input_lengths_tensor += accepted_ids
+batch.slot_indices += accepted_ids

if prefill and prefill_logprobs:
# Get prefill logprobs

@ -851,7 +919,7 @@ class FlashCausalLM(Model):

# GPU <-> CPU sync
next_token_logprobs = next_token_logprobs.tolist()
-next_token_ids = batch.input_ids.tolist()
+next_token_ids = next_input_ids.tolist()

# Zipped iterator
iterator = zip(

@ -864,13 +932,13 @@ class FlashCausalLM(Model):
batch.next_token_chooser.do_sample,
batch.next_token_chooser.seeds,
batch.top_n_tokens,
-next_token_ids,
-next_token_logprobs,
+accepted_ids,
batch_top_token_ids,
batch_top_token_logprobs,
)

# For each member of the batch
+index = 0
for i, (
request,
input_length,

@ -881,29 +949,43 @@ class FlashCausalLM(Model):
do_sample,
seed,
top_n_tokens,
-next_token_id,
-next_token_logprob,
+n_accepted_ids,
top_token_ids,
top_token_logprobs,
) in enumerate(iterator):
# Append next token to all tokens
-all_input_ids.append(next_token_id)
-# Generated token
-next_token_text, prefix_offset, read_offset = self.decode_token(
-all_input_ids,
-prefix_offset,
-read_offset,
-)
-# Evaluate stopping criteria
-stop, reason = stopping_criteria(
-next_token_id,
-next_token_text,
-)
+next_token_texts = []
+left = 0
+before = stopping_criteria.current_tokens
+
+current_stopped = False
+for j in range(index, index + n_accepted_ids):
+# Generated token
+next_token_id = next_token_ids[j]
+all_input_ids.append(next_token_id)
+next_token_text, prefix_offset, read_offset = self.decode_token(
+all_input_ids,
+prefix_offset,
+read_offset,
+)
+next_token_texts.append(next_token_text)
+
+stop, reason = stopping_criteria(
+next_token_id,
+next_token_text,
+)

-if not stop:
-stopped = False
+if stop:
+left = index + n_accepted_ids - j - 1
+current_stopped = True
+break
+else:
+current_stopped = False
+stopped = stopped and current_stopped
+
+_next_token_ids = next_token_ids[index: index+n_accepted_ids - left]
+_next_token_logprobs = next_token_logprobs[index: index+n_accepted_ids - left]
+index += n_accepted_ids

# Shard generations
# All generations will be appended in the rust sharded client

@ -943,8 +1025,9 @@ class FlashCausalLM(Model):
clean_up_tokenization_spaces=False,
skip_special_tokens=False,
)
-prefill_tokens = PrefillTokens(
-prefill_token_ids, request_prefill_logprobs, prefill_texts
+prefill_tokens = Tokens(
+prefill_token_ids, request_prefill_logprobs, prefill_texts, is_special = []
)
else:
prefill_tokens = None

@ -958,7 +1041,7 @@ class FlashCausalLM(Model):
special_toptokens = [
token_id in self.all_special_ids for token_id in top_token_ids
]
-top_tokens = TopTokens(
+top_tokens = Tokens(
top_token_ids,
top_token_logprobs,
toptoken_texts,

@ -970,10 +1053,12 @@ class FlashCausalLM(Model):
generation = Generation(
request.id,
prefill_tokens,
-next_token_id,
-next_token_logprob,
-next_token_text,
-next_token_id in self.all_special_ids,
+Tokens(
+_next_token_ids,
+_next_token_logprobs,
+next_token_texts,
+[nid in self.all_special_ids for nid in _next_token_ids],
+),
generated_text,
top_tokens,
)

@ -981,7 +1066,9 @@ class FlashCausalLM(Model):
generations.append(generation)

# Update values
-batch.input_lengths[i] = input_length + 1
+batch.input_lengths[i] = input_length + n_accepted_ids.item()
+if batch.input_lengths[i] > batch.max_seqlen:
+batch.max_seqlen = batch.input_lengths[i]
batch.prefix_offsets[i] = prefix_offset
batch.read_offsets[i] = read_offset
batch.all_input_ids[i] = all_input_ids

@ -994,6 +1081,5 @@ class FlashCausalLM(Model):
batch.prefill_cu_outlens = None
batch.prefill_head_indices = None
batch.prefill_next_token_indices = None
-batch.max_seqlen = batch.max_seqlen + 1

return generations, batch

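In the forward hunks above, a speculative decode step submits `speculative_length + 1` query positions per sequence, so input ids, position ids, slots and input lengths are all expanded before the flash-attention call. A minimal torch sketch of the same expansion arithmetic (standalone, for illustration only):

```python
# Minimal torch sketch of the tensor expansion done when speculative_ids is set:
# every sequence contributes speculative_length + 1 positions to one forward pass.
import torch


def expand_for_speculation(input_ids, speculative_ids, position_ids, slots, input_lengths):
    B, speculative_length = speculative_ids.shape
    new_length = speculative_length + 1
    # Interleave the last real token with its speculative continuations.
    new_input_ids = torch.cat([input_ids.unsqueeze(-1), speculative_ids], dim=1).reshape(-1)
    arange = torch.arange(new_length).unsqueeze(0)
    new_position_ids = (position_ids.unsqueeze(-1).expand(B, new_length) + arange).reshape(-1)
    new_slots = (slots.unsqueeze(-1).expand(B, new_length) + arange).reshape(-1)
    new_input_lengths = (input_lengths.unsqueeze(-1).expand(B, new_length) + arange).reshape(-1)
    return new_input_ids, new_position_ids, new_slots, new_input_lengths


# Two sequences, two speculative ids each -> six query positions in one forward pass.
ids, pos, slots, lens = expand_for_speculation(
    torch.tensor([10, 20]), torch.tensor([[11, 12], [21, 22]]),
    torch.tensor([5, 7]), torch.tensor([100, 200]), torch.tensor([6, 8]),
)
assert ids.tolist() == [10, 11, 12, 20, 21, 22] and pos.tolist() == [5, 6, 7, 7, 8, 9]
```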
@ -28,6 +28,7 @@ class FlashLlama(FlashCausalLM):
quantize: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
+use_medusa: Optional[str] = None,
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():

@ -66,6 +67,18 @@ class FlashLlama(FlashCausalLM):
weights._set_gptq_params(model_id)

model = FlashLlamaForCausalLM(config, weights)
+if use_medusa:
+from text_generation_server.utils.medusa import MedusaModel
+from huggingface_hub import hf_hub_download
+import json
+medusa_config = hf_hub_download(use_medusa, revision=revision, filename="config.json")
+with open(medusa_config, "r") as f:
+config = json.load(f)
+medusa_head = hf_hub_download(use_medusa, revision=revision, filename="medusa_lm_head.pt")
+medusa_sf = medusa_head[:-len(".pt")] + ".safetensors"
+weights = Weights([medusa_sf], device, dtype, process_group=self.process_group)
+lm_head = model.lm_head
+model.lm_head = MedusaModel(config, weights, lm_head)

torch.distributed.barrier(group=self.process_group)
super(FlashLlama, self).__init__(

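The FlashLlama hunk above swaps `model.lm_head` for a `MedusaModel` wrapper loaded from the converted safetensors. The toy module below only illustrates the general idea of an lm_head that also emits speculative logits; it is an assumption about the interface, not the actual `MedusaModel`:

```python
# Illustrative wrapper in the spirit of the MedusaModel hook above: the original
# lm_head keeps producing the main logits, while extra heads (loaded from the
# Medusa safetensors in the real code) add one set of speculative logits per head.
import torch
from torch import nn


class SpeculativeLMHead(nn.Module):
    def __init__(self, lm_head: nn.Linear, n_heads: int):
        super().__init__()
        self.lm_head = lm_head
        hidden, vocab = lm_head.in_features, lm_head.out_features
        self.heads = nn.ModuleList(nn.Linear(hidden, vocab) for _ in range(n_heads))

    def forward(self, hidden_states: torch.Tensor):
        logits = self.lm_head(hidden_states)
        # Shape (batch, n_heads, vocab): one extra prediction per speculative head.
        speculative_logits = torch.stack([h(hidden_states) for h in self.heads], dim=1)
        # Matches the (out, speculative_logits) tuple unpacked in generate_token above.
        return logits, speculative_logits
```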
@ -21,6 +21,7 @@ from text_generation_server.models.custom_modeling.flash_mistral_modeling import
|
|||||||
FlashMistralForCausalLM,
|
FlashMistralForCausalLM,
|
||||||
MistralConfig,
|
MistralConfig,
|
||||||
)
|
)
|
||||||
|
from text_generation_server.utils.speculate import get_speculate
|
||||||
from text_generation_server.utils import (
|
from text_generation_server.utils import (
|
||||||
initialize_torch_distributed,
|
initialize_torch_distributed,
|
||||||
weight_files,
|
weight_files,
|
||||||
@ -132,7 +133,8 @@ class FlashMistralBatch(FlashCausalLMBatch):
|
|||||||
|
|
||||||
# Paged attention
|
# Paged attention
|
||||||
# Remove one as the first token des not have a past
|
# Remove one as the first token des not have a past
|
||||||
total_tokens = input_length + max_new_tokens - 1
|
speculative_length = get_speculate()
|
||||||
|
total_tokens = input_length + max_new_tokens - 1 + speculative_length
|
||||||
|
|
||||||
# Needed blocks can not go over SLIDING_WINDOW_BLOCKS
|
# Needed blocks can not go over SLIDING_WINDOW_BLOCKS
|
||||||
needed_blocks = min(
|
needed_blocks = min(
|
||||||
@ -183,7 +185,7 @@ class FlashMistralBatch(FlashCausalLMBatch):
|
|||||||
cumulative_max_length += total_tokens
|
cumulative_max_length += total_tokens
|
||||||
max_seqlen = max(max_seqlen, input_length)
|
max_seqlen = max(max_seqlen, input_length)
|
||||||
max_blocks = max(max_blocks, needed_blocks)
|
max_blocks = max(max_blocks, needed_blocks)
|
||||||
max_length = max(max_length, input_length + max_new_tokens)
|
max_length = max(max_length, input_length + max_new_tokens + speculative_length)
|
||||||
|
|
||||||
next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
|
next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
|
||||||
next_token_chooser_parameters, dtype, device
|
next_token_chooser_parameters, dtype, device
|
||||||
@ -272,6 +274,7 @@ class FlashMistralBatch(FlashCausalLMBatch):
|
|||||||
blocks=blocks,
|
blocks=blocks,
|
||||||
max_blocks=max_blocks,
|
max_blocks=max_blocks,
|
||||||
prefill_cache_indices=prefill_cache_indices,
|
prefill_cache_indices=prefill_cache_indices,
|
||||||
|
speculative_ids=None
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -340,17 +343,55 @@ class FlashMistral(FlashCausalLM):

     def forward(self, batch: FlashMistralBatch) -> Tuple[torch.Tensor, torch.Tensor]:
         # Model Forward
+        if batch.speculative_ids is not None:
+            input_ids = batch.input_ids
+            position_ids = batch.position_ids
+            cu_seqlen_prefill = batch.cu_seqlen_prefill
+            kv_cache = get_cache_manager().kv_cache
+            block_tables = batch.block_tables_tensor
+            slots = batch.slots[batch.slot_indices]
+            input_lengths = batch.input_lengths_tensor
+            max_s = batch.max_seqlen
+            lm_head_indices = batch.prefill_head_indices
+
+            speculative_ids = batch.speculative_ids
+
+            B, speculative_length = speculative_ids.shape
+            new_length = speculative_length + 1
+            new_input_ids = torch.cat([input_ids.unsqueeze(-1), speculative_ids], dim=1).reshape(-1)
+            arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0)
+            arange_int = arange.to(dtype=torch.int32)
+            new_position_ids = (position_ids.unsqueeze(-1).expand(B, new_length) + arange).view(-1)
+            slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
+            input_lengths = (input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
+
+            # Copy the block tables for all members
+            block_tables = block_tables.unsqueeze(1).expand(B, new_length, -1).reshape(B * new_length, -1).contiguous()
+            max_s = max_s + speculative_length
+
+            input_ids = new_input_ids
+            position_ids = new_position_ids
+        else:
+            input_ids = batch.input_ids
+            position_ids = batch.position_ids
+            cu_seqlen_prefill = batch.cu_seqlen_prefill
+            kv_cache = get_cache_manager().kv_cache
+            block_tables = batch.block_tables_tensor
+            slots = batch.slots[batch.slot_indices]
+            input_lengths = batch.input_lengths_tensor
+            max_s = batch.max_seqlen
+            lm_head_indices = batch.prefill_head_indices
         logits = self.model.forward(
-            input_ids=batch.input_ids,
-            position_ids=batch.position_ids,
-            cu_seqlen_prefill=batch.cu_seqlen_prefill,
-            kv_cache=get_cache_manager().kv_cache,
-            block_tables=batch.block_tables_tensor,
-            slots=batch.slots[batch.slot_indices],
-            input_lengths=batch.input_lengths_tensor,
-            max_s=batch.max_seqlen,
+            input_ids=input_ids,
+            position_ids=position_ids,
+            cu_seqlen_prefill=cu_seqlen_prefill,
+            kv_cache=kv_cache,
+            block_tables=block_tables,
+            slots=slots,
+            input_lengths=input_lengths,
+            max_s=max_s,
             prefill_cache_indices=batch.prefill_cache_indices,
-            lm_head_indices=batch.prefill_head_indices,
+            lm_head_indices=lm_head_indices,
         )
         if batch.prefill_cache_indices is not None:
             batch.prefill_cache_indices = None
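For intuition, the expansion in the speculative branch can be reproduced on toy tensors. The sketch below (standalone, made-up values) shows how each sequence's last token plus its speculative ids are flattened into one long batch and how the positions advance by 0..speculative_length:

# Toy illustration of the expansion above (not TGI code): each sequence's last
# verified token is concatenated with its speculative ids and flattened, and
# the per-token positions are advanced by 0..speculative_length.
import torch

B, speculative_length = 2, 3
new_length = speculative_length + 1

input_ids = torch.tensor([11, 22])                      # last verified token per sequence
speculative_ids = torch.tensor([[1, 2, 3], [4, 5, 6]])  # draft tokens per sequence
position_ids = torch.tensor([7, 9])                     # next position per sequence

new_input_ids = torch.cat([input_ids.unsqueeze(-1), speculative_ids], dim=1).reshape(-1)
arange = torch.arange(new_length).unsqueeze(0)
new_position_ids = (position_ids.unsqueeze(-1).expand(B, new_length) + arange).view(-1)

print(new_input_ids)     # tensor([11, 1, 2, 3, 22, 4, 5, 6])
print(new_position_ids)  # tensor([ 7, 8, 9, 10, 9, 10, 11, 12])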
@@ -20,7 +20,7 @@ from typing import Optional, Tuple, List, Type, Dict
 from text_generation_server.models import Model
 from text_generation_server.models.types import (
     Batch,
-    PrefillTokens,
+    Tokens,
     Generation,
     GeneratedText,
 )
@@ -791,8 +791,8 @@ class IdeficsCausalLM(Model):
                     clean_up_tokenization_spaces=False,
                     skip_special_tokens=False,
                 )
-                prefill_tokens = PrefillTokens(
-                    prefill_token_ids, prefill_logprobs, prefill_texts
+                prefill_tokens = Tokens(
+                    prefill_token_ids, prefill_logprobs, prefill_texts, is_special=[]
                 )
             else:
                 prefill_tokens = None
@@ -802,10 +802,12 @@ class IdeficsCausalLM(Model):
             generation = Generation(
                 request.id,
                 prefill_tokens,
-                next_token_id_squeezed,
-                next_token_logprob,
-                next_token_text,
-                next_token_id_squeezed.item() in self.all_special_ids,
+                Tokens(
+                    [next_token_id_squeezed],
+                    [next_token_logprob],
+                    [next_token_text],
+                    [next_token_id_squeezed.item() in self.all_special_ids],
+                ),
                 generated_text,
                 top_tokens,
             )
@@ -5,7 +5,8 @@ from abc import ABC, abstractmethod
 from typing import List, Optional, Tuple, Type, TypeVar
 from transformers import PreTrainedTokenizerBase

-from text_generation_server.models.types import Batch, GeneratedText
+from text_generation_server.models.types import Batch, Generation
+from text_generation_server.utils.speculate import get_speculate
 from text_generation_server.pb.generate_pb2 import InfoResponse

 B = TypeVar("B", bound=Batch)
@@ -22,6 +23,7 @@ class Model(ABC):
         rank: int = 0,
         world_size: int = 1,
         kwargs: dict = {},
+        speculate: Optional[int] = None,
     ):
         self.model = model
         self.tokenizer = tokenizer
@@ -32,7 +34,14 @@ class Model(ABC):
         self.rank = rank
         self.world_size = world_size
         self.kwargs = kwargs
-        self.has_position_ids = inspect.signature(model.forward).parameters.get("position_ids", None) is not None
+        if speculate is None:
+            speculate = get_speculate()
+        self.speculate = speculate
+
+        self.has_position_ids = (
+            inspect.signature(model.forward).parameters.get("position_ids", None)
+            is not None
+        )

         self.check_initialized()

@@ -42,6 +51,7 @@ class Model(ABC):
             requires_padding=self.requires_padding,
             dtype=str(self.dtype),
             device_type=self.device.type,
+            speculate=self.speculate
         )

     @property
@@ -50,7 +60,7 @@ class Model(ABC):
         raise NotImplementedError

     @abstractmethod
-    def generate_token(self, batch: B) -> Tuple[List[GeneratedText], Optional[B]]:
+    def generate_token(self, batch: B) -> Tuple[List[Generation], Optional[B]]:
         raise NotImplementedError

     def warmup(self, batch: B, max_total_tokens: int):
@@ -11,8 +11,7 @@ from text_generation_server.models.types import (
     GeneratedText,
     Batch,
     Generation,
-    PrefillTokens,
-    TopTokens,
+    Tokens,
 )
 from text_generation_server.pb import generate_pb2
 from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
@@ -733,10 +732,11 @@ class Seq2SeqLM(Model):

             # Prefill
             if stopping_criteria.current_tokens == 1 and request.prefill_logprobs:
-                prefill_tokens = PrefillTokens(
+                prefill_tokens = Tokens(
                     [self.tokenizer.bos_token_id],
                     [float("nan")],
                     [self.tokenizer.bos_token],
+                    [False]
                 )
             else:
                 prefill_tokens = None
@@ -750,7 +750,7 @@ class Seq2SeqLM(Model):
                 special_toptokens = [
                     token_id in self.all_special_ids for token_id in top_token_ids
                 ]
-                top_tokens = TopTokens(
+                top_tokens = Tokens(
                     top_token_ids,
                     top_token_logprobs,
                     toptoken_texts,
@@ -762,10 +762,12 @@ class Seq2SeqLM(Model):
             generation = Generation(
                 request.id,
                 prefill_tokens,
-                next_token_id_squeezed,
-                next_token_logprob,
-                next_token_text,
-                next_token_id_squeezed.item() in self.all_special_ids,
+                Tokens(
+                    [next_token_id_squeezed],
+                    [next_token_logprob],
+                    [next_token_text],
+                    [next_token_id_squeezed.item() in self.all_special_ids],
+                ),
                 generated_text,
                 top_tokens,
             )
@@ -58,33 +58,15 @@ class GeneratedText:


 @dataclass
-class PrefillTokens:
-    token_ids: List[int]
-    logprobs: List[float]
-    texts: List[str]
-
-    def to_pb(self) -> generate_pb2.PrefillTokens:
-        return generate_pb2.PrefillTokens(
-            ids=self.token_ids, logprobs=self.logprobs, texts=self.texts
-        )
-
-    def __len__(self):
-        return len(self.token_ids)
-
-
-@dataclass
-class TopTokens:
+class Tokens:
     token_ids: List[int]
     logprobs: List[float]
     texts: List[str]
     is_special: List[bool]

-    def to_pb(self) -> generate_pb2.TopTokens:
-        return generate_pb2.TopTokens(
-            ids=self.token_ids,
-            logprobs=self.logprobs,
-            texts=self.texts,
-            is_special=self.is_special,
+    def to_pb(self) -> generate_pb2.Tokens:
+        return generate_pb2.Tokens(
+            ids=self.token_ids, logprobs=self.logprobs, texts=self.texts, is_special=self.is_special
         )

     def __len__(self):
@@ -94,14 +76,11 @@ class TopTokens:
 @dataclass
 class Generation:
     request_id: int
-    prefill_tokens: Optional[PrefillTokens]
-    token_id: int
-    token_logprob: float
-    token_text: str
-    token_is_special: bool
+    prefill_tokens: Optional[Tokens]
+    tokens: Tokens
     generated_text: Optional[GeneratedText]
     # Optional for now, since it's not yet supported for every model.
-    top_tokens: Optional[TopTokens]
+    top_tokens: Optional[List[Tokens]]

     def to_pb(self) -> generate_pb2.Generation:
         return generate_pb2.Generation(
@@ -109,10 +88,7 @@ class Generation:
             prefill_tokens=self.prefill_tokens.to_pb()
             if self.prefill_tokens is not None
             else None,
-            token_id=self.token_id,
-            token_logprob=self.token_logprob,
-            token_text=self.token_text,
-            token_is_special=self.token_is_special,
+            tokens=self.tokens.to_pb(),
             generated_text=self.generated_text.to_pb()
             if self.generated_text is not None
             else None,
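As an aside, the net effect of this refactor is that one `Tokens` container now carries prefill tokens, the one-or-more tokens emitted per decoding step (more than one when speculative tokens are accepted), and top tokens. A stripped-down illustration with a plain dataclass (not the real protobuf-backed type):

# Illustration only (plain dataclass, not the protobuf-backed class above):
# the same Tokens shape serves prefill, per-step and top-token payloads.
from dataclasses import dataclass
from typing import List

@dataclass
class Tokens:
    token_ids: List[int]
    logprobs: List[float]
    texts: List[str]
    is_special: List[bool]

    def __len__(self) -> int:
        return len(self.token_ids)

# A decoding step that kept the verified token plus two accepted speculative tokens:
step_tokens = Tokens([11, 12, 13], [-0.1, -0.4, -0.2], ["He", "llo", "!"], [False, False, False])
print(len(step_tokens))  # 3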
@@ -107,9 +107,11 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
 def serve(
     model_id: str,
     revision: Optional[str],
-    dtype: Optional[str],
-    uds_path: Path,
     sharded: bool,
+    speculate: Optional[int],
+    dtype: Optional[str],
+    trust_remote_code: bool,
+    uds_path: Path,
 ):
     # Remove default handler
     logger.remove()
@@ -126,8 +128,10 @@ def serve(
     async def serve_inner(
         model_id: str,
         revision: Optional[str],
-        dtype: Optional[str] = None,
         sharded: bool = False,
+        speculate: Optional[int] = None,
+        dtype: Optional[str] = None,
+        trust_remote_code: bool = False,
     ):
         unix_socket_template = "unix://{}-{}"
         logger.info("Server:server_inner: sharded ={}".format(sharded))
@@ -151,7 +155,9 @@ def serve(
         if revision == "None":
            revision = None
         try:
-            model = get_model(model_id, revision=revision, dtype=data_type)
+            model = get_model(
+                model_id, revision, speculate, dtype=data_type, trust_remote_code=trust_remote_code
+            )
         except Exception:
             logger.exception("Error when initializing model")
             raise
@@ -181,13 +187,7 @@ def serve(
         except KeyboardInterrupt:
             logger.info("Signal received. Shutting down")
             await server.stop(0)
-        finally:
-            if hasattr(model, 'finish_quantization_measurements'):
-                model.finish_quantization_measurements()

-    logger.info(
-        "Starting Server : model_id= {}, revision = {} dtype = {} sharded = {} ".format(
-            model_id, revision, dtype, sharded
-        )
+    asyncio.run(
+        serve_inner(model_id, revision, sharded, speculate, dtype, trust_remote_code)
     )
-    asyncio.run(serve_inner(model_id, revision, dtype, sharded))
@@ -9,12 +9,18 @@ import argparse
 def main(args):
     logger.info("TGIService: starting tgi service .... ")
     logger.info(
-        "TGIService: --model_id {}, --revision {}, --sharded {}, --dtype {}, --uds_path {} ".format(
-            args.model_id, args.revision, args.sharded, args.dtype, args.uds_path
+        "TGIService: --model_id {}, --revision {}, --sharded {}, --speculate {}, --dtype {}, --trust_remote_code {}, --uds_path {} ".format(
+            args.model_id, args.revision, args.sharded, args.speculate, args.dtype, args.trust_remote_code, args.uds_path
         )
     )
     server.serve(
-        model_id=args.model_id, revision=args.revision, dtype=args.dtype, uds_path=args.uds_path, sharded=args.sharded
+        model_id=args.model_id,
+        revision=args.revision,
+        sharded=args.sharded,
+        speculate=args.speculate,
+        dtype=args.dtype,
+        trust_remote_code=args.trust_remote_code,
+        uds_path=args.uds_path,
     )


@@ -23,7 +29,9 @@ if __name__ == "__main__":
     parser.add_argument("--model_id", type=str)
     parser.add_argument("--revision", type=str)
     parser.add_argument("--sharded", type=bool)
+    parser.add_argument("--speculate", type=int, default=None)
     parser.add_argument("--dtype", type=str)
+    parser.add_argument("--trust_remote_code", type=bool)
     parser.add_argument("--uds_path", type=Path)
     args = parser.parse_args()
     main(args)
server/text_generation_server/utils/medusa.py (new file, 51 lines)
@@ -0,0 +1,51 @@
+import torch
+from dataclasses import dataclass
+from text_generation_server.utils.layers import TensorParallelHead, FastLinear
+
+
+@dataclass
+class Output:
+    logits: torch.FloatTensor = None
+    speculative_logits: torch.FloatTensor = None
+
+
+class ResBlock(torch.nn.Module):
+    def __init__(self, config, prefix, weights):
+        super().__init__()
+        self.linear = FastLinear.load(config, prefix=f"{prefix}.linear", weights=weights, bias=True)
+        self.act = torch.nn.SiLU()
+
+    def forward(self, x):
+        return x + self.act(self.linear(x))
+
+
+class MedusaModel(torch.nn.Module):
+    def __init__(
+        self,
+        config,
+        weights,
+        lm_head
+    ):
+        super().__init__()
+        self.heads = torch.nn.ModuleList(
+            [MedusaHead(config, prefix=f"{i}", weights=weights) for i in range(config["medusa_num_heads"])]
+        )
+        self.lm_head = lm_head
+
+    def forward(self, x):
+        logits = self.lm_head(x)
+        speculative_logits = torch.stack([head(x) for head in self.heads], dim=1)
+        return logits, speculative_logits
+
+
+class MedusaHead(torch.nn.Module):
+    def __init__(self, config, prefix, weights):
+        super().__init__()
+        self.blocks = torch.nn.ModuleList([ResBlock(config, prefix=f"{prefix}.{i}", weights=weights) for i in range(config["medusa_num_layers"])])
+        n = len(self.blocks)
+        self.out = FastLinear.load(config, prefix=f"{prefix}.{n}", weights=weights, bias=False)
+
+    def forward(self, x):
+        for block in self.blocks:
+            x = block(x)
+        x = self.out(x)
+        return x
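To make the Medusa shapes concrete, here is a toy stand-in (illustration only: plain `torch.nn.Linear` replaces `FastLinear.load`, and the sizes are made up) showing that the speculative logits gain a head dimension:

# Toy stand-in for MedusaModel to show the output shapes (not TGI code):
# each head is a residual block followed by a projection to the vocabulary,
# so speculative_logits carries one extra dimension for the head index.
import torch

hidden, vocab, num_heads = 16, 100, 3

class ToyHead(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.res = torch.nn.Linear(hidden, hidden)
        self.act = torch.nn.SiLU()
        self.out = torch.nn.Linear(hidden, vocab, bias=False)

    def forward(self, x):
        return self.out(x + self.act(self.res(x)))

lm_head = torch.nn.Linear(hidden, vocab, bias=False)
heads = torch.nn.ModuleList(ToyHead() for _ in range(num_heads))

x = torch.randn(5, hidden)                     # 5 token positions
logits = lm_head(x)                            # (5, vocab): next-token logits
speculative_logits = torch.stack([h(x) for h in heads], dim=1)
print(logits.shape, speculative_logits.shape)  # torch.Size([5, 100]) torch.Size([5, 3, 100])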
server/text_generation_server/utils/speculate.py (new file, 12 lines)
@@ -0,0 +1,12 @@
+SPECULATE = None
+
+
+def get_speculate() -> int:
+    global SPECULATE
+    return SPECULATE
+
+
+def set_speculate(speculate: int):
+    global SPECULATE
+    SPECULATE = speculate
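Presumably the intended use of this process-global value is to call `set_speculate` once at server startup and read it everywhere else; a minimal sketch:

# Sketch of the intended flow (assumption: set once at server startup,
# before any Model or batch is built).
from text_generation_server.utils.speculate import set_speculate, get_speculate

set_speculate(2)        # e.g. derived from the --speculate flag or the Medusa config
print(get_speculate())  # 2 -> batches now reserve 2 extra slots per sequence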
@@ -142,6 +142,22 @@ class StoppingCriteria:
         )


+def create_n_gram_speculation(input_ids: torch.Tensor, next_ids: torch.Tensor, accepted_ids: torch.Tensor, speculate: int, verbose: bool):
+    # Very trivial approach, find first match in the string.
+    # This is much less refined than actual n-gram but seems to work
+    # relatively OK in grounded mode and is by far much faster with
+    # much less worst case complexity as everything happens on device.
+    B = accepted_ids.shape[0]
+    device = input_ids.device
+    seeds = next_ids[accepted_ids.cumsum(dim=-1) - 1]
+    indices = (input_ids == seeds.unsqueeze(-1)).max(dim=1).indices + 1
+    all_indices = indices.unsqueeze(-1).expand(B, speculate) + torch.arange(speculate, device=device)
+    all_indices = torch.clamp(all_indices, max=input_ids.shape[1] - 1)
+
+    speculative_ids = input_ids.gather(dim=-1, index=all_indices)
+    return speculative_ids
+
+
 class HeterogeneousNextTokenChooser:
     def __init__(
         self,
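A small worked example of the n-gram helper above on toy data (standalone, re-running the same tensor ops rather than importing the server code):

# Worked example of create_n_gram_speculation on a batch of one sequence.
import torch

speculate = 3
input_ids = torch.tensor([[5, 8, 9, 3, 5, 8, 7]])  # already-generated ids
next_ids = torch.tensor([5])                       # token just produced by the model
accepted_ids = torch.tensor([1])                   # one token accepted this step

B = accepted_ids.shape[0]
seeds = next_ids[accepted_ids.cumsum(dim=-1) - 1]  # last accepted token -> tensor([5])
indices = (input_ids == seeds.unsqueeze(-1)).max(dim=1).indices + 1  # first 5 is at 0 -> start at 1
all_indices = indices.unsqueeze(-1).expand(B, speculate) + torch.arange(speculate)
all_indices = torch.clamp(all_indices, max=input_ids.shape[1] - 1)

speculative_ids = input_ids.gather(dim=-1, index=all_indices)
print(speculative_ids)  # tensor([[8, 9, 3]]) -- the continuation seen after the first 5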
@@ -206,16 +222,72 @@ class HeterogeneousNextTokenChooser:
         self.dtype = dtype
         self.device = device

-    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor):
-        if self.watermark_processor is not None:
-            scores = self.watermark_processor(input_ids, scores)
-        if self.repetition_processor is not None:
-            scores = self.repetition_processor(input_ids, scores)
+    def __call__(
+        self,
+        input_ids: torch.Tensor,
+        scores: torch.Tensor,
+        speculate: int,
+        speculated_ids: Optional[torch.Tensor] = None,
+        speculative_scores: Optional[torch.Tensor] = None,
+        verbose=False
+    ):
+        if speculated_ids is not None:
+            B = scores.shape[0] // (speculated_ids.shape[1] + 1)
+            S = speculated_ids.shape[1] + 1
+            scores = scores.view(B, S, -1)
+        else:
+            B = scores.shape[0]
+            S = 1
+            scores = scores.view(B, S, -1)

-        for warper in self.warpers:
-            scores = warper(input_ids, scores)
+        next_ids = torch.zeros((B, S), device=scores.device, dtype=torch.long)
+        for j in range(S):
+            _scores = scores[:, j]
+            if self.watermark_processor is not None:
+                _scores = self.watermark_processor(input_ids, _scores)
+            if self.repetition_processor is not None:
+                _scores = self.repetition_processor(input_ids, _scores)
+
+            for warper in self.warpers:
+                _scores = warper(input_ids, _scores)
+
+            _next_ids = self.choice(_scores)
+            scores[:, j] = _scores
+            next_ids[:, j] = _next_ids
+        next_ids = next_ids.view(B * S)
+        scores = scores.view(B * S, -1)
+
+        if speculated_ids is not None:
+            accepted_ids = []
+            B = next_ids.shape[0] // (speculated_ids.shape[1] + 1)
+            S = speculated_ids.shape[1] + 1
+            indices = []
+            for i in range(B):
+                _next_ids = next_ids[i * S : (i + 1) * S]
+                _speculated_ids = speculated_ids[i]
+                validate_speculative = _next_ids[:-1] == _speculated_ids
+                index = i * S
+                accepted = 1
+                # First is always valid
+                indices.append(index)
+                for valid in validate_speculative.tolist():
+                    if valid:
+                        index += 1
+                        accepted += 1
+                        indices.append(index)
+                    else:
+                        break
+                accepted_ids.append(accepted)
+
+            accepted_ids = torch.tensor(accepted_ids, device=input_ids.device, dtype=input_ids.dtype)
+            next_ids = next_ids[indices]
+            scores = scores[indices]
+            indices = torch.arange(B, device=input_ids.device) * S
+            if speculative_scores is not None:
+                speculative_scores = speculative_scores[indices + accepted_ids - 1]
+        else:
+            accepted_ids = torch.ones_like(next_ids)

-        next_ids = self.choice(scores)
         # ignore logprobs if we use greedy search
         if type(self.choice) == Greedy:
             logprobs = torch.empty_like(scores, device="cpu")
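The accept/advance loop above is easiest to follow on plain Python lists; a toy re-run of the same logic for one sequence (S = 4 scored positions checked against 3 previously speculated ids):

# Toy re-run of the acceptance loop above, on plain lists instead of tensors.
next_ids = [7, 21, 9, 13]    # model's choice at each of the S scored positions
speculated_ids = [7, 9, 13]  # what the draft (Medusa / n-gram) had guessed

validate_speculative = [n == s for n, s in zip(next_ids[:-1], speculated_ids)]
# next_ids[:-1] = [7, 21, 9] vs [7, 9, 13] -> [True, False, False]

index = 0
accepted = 1                 # the first sampled id is always kept
indices = [index]
for valid in validate_speculative:
    if valid:
        index += 1
        accepted += 1
        indices.append(index)
    else:
        break

print(indices, accepted)     # [0, 1] 2 -> keep next_ids[0] and next_ids[1]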
@@ -224,7 +296,17 @@ class HeterogeneousNextTokenChooser:
             logprobs = torch.log_softmax(scores, -1)
         next_logprobs = torch.gather(logprobs, 1, next_ids.view(-1, 1)).view(-1)

-        return next_ids, next_logprobs, logprobs
+        if speculate > 0:
+            if speculative_scores is not None:
+                # Medusa provided some scores
+                speculative_ids = Greedy()(speculative_scores)
+            else:
+                # n-gram
+                speculative_ids = create_n_gram_speculation(input_ids, next_ids, accepted_ids, speculate, verbose)
+        else:
+            speculative_ids = None
+
+        return next_ids, next_logprobs, logprobs, accepted_ids, speculative_ids

     def filter(self, indices):
         if self.watermark_processor is not None: