Speculative (#1308)

Nicolas Patry 2023-12-11 12:46:30 +01:00 committed by Karol Damaszke
parent a41c1a6bc7
commit a7f52f3812
35 changed files with 1511 additions and 398 deletions

Cargo.lock

@ -88,9 +88,9 @@ dependencies = [
[[package]]
name = "anyhow"
version = "1.0.81"
version = "1.0.82"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0952808a6c2afd1aa8947271f3a60f1a6763c7b912d210184c5149b5cf147247"
checksum = "f538837af36e6f6a9be0faa67f9a314f8119e4e4b5867c6ab40ed60360142519"
[[package]]
name = "arc-swap"
@ -128,18 +128,18 @@ checksum = "16e62a023e7c117e27523144c5d2459f4397fcc3cab0085af8e2224f643a0193"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.58",
"syn 2.0.60",
]
[[package]]
name = "async-trait"
version = "0.1.79"
version = "0.1.80"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a507401cad91ec6a857ed5513a2073c82a9b9048762b885bb98655b306964681"
checksum = "c6fa2087f2753a7da8cc1c0dbfcf89579dd57458e36769de5ac750b4671737ca"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.58",
"syn 2.0.60",
]
[[package]]
@ -288,9 +288,9 @@ dependencies = [
[[package]]
name = "bumpalo"
version = "3.15.4"
version = "3.16.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7ff69b9dd49fd426c69a0db9fc04dd934cdb6645ff000864d98f7e2af8830eaa"
checksum = "79296716171880943b8470b5f8d03aa55eb2e645a4874bdbb28adb49162e012c"
[[package]]
name = "bytecount"
@ -350,9 +350,9 @@ checksum = "df8670b8c7b9dae1793364eafadf7239c40d669904660c5960d74cfd80b46a53"
[[package]]
name = "cc"
version = "1.0.90"
version = "1.0.94"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8cd6604a82acf3039f1144f54b8eb34e91ffba622051189e71b781822d5ee1f5"
checksum = "17f6e324229dc011159fcc089755d1e2e216a90d43a7dea6853ca740b84f35e7"
[[package]]
name = "cfg-if"
@ -397,7 +397,7 @@ dependencies = [
"heck 0.5.0",
"proc-macro2",
"quote",
"syn 2.0.58",
"syn 2.0.60",
]
[[package]]
@ -675,9 +675,9 @@ dependencies = [
[[package]]
name = "either"
version = "1.10.0"
version = "1.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "11157ac094ffbdde99aa67b23417ebdd801842852b500e395a45a9c0aac03e4a"
checksum = "a47c1c47d2f5964e29c61246e81db715514cd532db6b5116a25ea3c03d6780a2"
[[package]]
name = "encode_unicode"
@ -687,9 +687,9 @@ checksum = "a357d28ed41a50f9c765dbfe56cbc04a64e53e5fc58ba79fbc34c10ef3df831f"
[[package]]
name = "encoding_rs"
version = "0.8.33"
version = "0.8.34"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7268b386296a025e474d5140678f75d6de9493ae55a5d709eeb9dd08149945e1"
checksum = "b45de904aa0b010bce2ab45264d0631681847fa7b6f2eaa7dab7619943bc4f59"
dependencies = [
"cfg-if",
]
@ -839,7 +839,7 @@ checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.58",
"syn 2.0.60",
]
[[package]]
@ -893,9 +893,9 @@ dependencies = [
[[package]]
name = "getrandom"
version = "0.2.12"
version = "0.2.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "190092ea657667030ac6a35e305e62fc4dd69fd98ac98631e5d3a2b1575a12b5"
checksum = "94b22e06ecb0110981051723910cbf0b5f5e09a2062dd7663334ee79a9d1286c"
dependencies = [
"cfg-if",
"libc",
@ -920,9 +920,9 @@ dependencies = [
[[package]]
name = "h2"
version = "0.3.25"
version = "0.3.26"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4fbd2820c5e49886948654ab546d0688ff24530286bdcf8fca3cefb16d4618eb"
checksum = "81fe527a889e1532da5c525686d96d4c2e74cdd345badf8dfef9f6b39dd5f5e8"
dependencies = [
"bytes",
"fnv",
@ -993,15 +993,6 @@ dependencies = [
"ureq",
]
[[package]]
name = "home"
version = "0.5.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e3d1354bf6b7235cb4a0576c2619fd4ed18183f689b12b006a0ee7329eeff9a5"
dependencies = [
"windows-sys 0.52.0",
]
[[package]]
name = "hostname"
version = "0.3.1"
@ -1204,6 +1195,15 @@ dependencies = [
"either",
]
[[package]]
name = "itertools"
version = "0.12.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ba291022dbbd398a455acf126c1e341954079855bc60dfdda641363bd6922569"
dependencies = [
"either",
]
[[package]]
name = "itoa"
version = "1.0.11"
@ -1358,7 +1358,7 @@ checksum = "38b4faf00617defe497754acde3024865bc143d44a86799b24e191ecff91354f"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.58",
"syn 2.0.60",
]
[[package]]
@ -1421,9 +1421,9 @@ dependencies = [
[[package]]
name = "monostate"
version = "0.1.11"
version = "0.1.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "878c2a1f1c70e5724fa28f101ca787b6a7e8ad5c5e4ae4ca3b0fa4a419fa9075"
checksum = "a20fffcd8ca4c69d31e036a71abc400147b41f90895df4edcb36497a1f8af8bf"
dependencies = [
"monostate-impl",
"serde",
@ -1431,20 +1431,20 @@ dependencies = [
[[package]]
name = "monostate-impl"
version = "0.1.11"
version = "0.1.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f686d68a09079e63b1d2c64aa305095887ce50565f00a922ebfaeeee0d9ba6ce"
checksum = "bf307cbbbd777a9c10cec88ddafee572b3484caad5cce0c9236523c3803105a6"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.58",
"syn 2.0.60",
]
[[package]]
name = "multimap"
version = "0.8.3"
version = "0.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e5ce46fe64a9d73be07dcbe690a38ce1b293be448fd8ce1e6c1b8062c9f72c6a"
checksum = "defc4c55412d89136f966bbb339008b474350e5e6e78d2714439c386b3137a03"
[[package]]
name = "muxado"
@ -1662,7 +1662,7 @@ checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.58",
"syn 2.0.60",
]
[[package]]
@ -1878,7 +1878,7 @@ checksum = "2f38a4412a78282e09a2cf38d195ea5420d15ba0602cb375210efbc877243965"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.58",
"syn 2.0.60",
]
[[package]]
@ -1919,12 +1919,12 @@ checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de"
[[package]]
name = "prettyplease"
version = "0.2.17"
version = "0.2.19"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8d3928fb5db768cb86f891ff014f0144589297e3c6a1aba6ed7cecfdace270c7"
checksum = "5ac2cf0f2e4f42b49f5ffd07dae8d746508ef7526c13940e5f524012ae6c6550"
dependencies = [
"proc-macro2",
"syn 2.0.58",
"syn 2.0.60",
]
[[package]]
@ -1953,9 +1953,9 @@ dependencies = [
[[package]]
name = "proc-macro2"
version = "1.0.79"
version = "1.0.81"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e835ff2298f5721608eb1a980ecaee1aef2c132bf95ecc026a11b7bf3c01c02e"
checksum = "3d1597b0c024618f09a9c3b8655b7e430397a36d23fdafec26d6965e9eec3eba"
dependencies = [
"unicode-ident",
]
@ -1972,34 +1972,33 @@ dependencies = [
[[package]]
name = "prost"
version = "0.12.3"
version = "0.12.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "146c289cda302b98a28d40c8b3b90498d6e526dd24ac2ecea73e4e491685b94a"
checksum = "d0f5d036824e4761737860779c906171497f6d55681139d8312388f8fe398922"
dependencies = [
"bytes",
"prost-derive 0.12.3",
"prost-derive 0.12.4",
]
[[package]]
name = "prost-build"
version = "0.12.3"
version = "0.12.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c55e02e35260070b6f716a2423c2ff1c3bb1642ddca6f99e1f26d06268a0e2d2"
checksum = "80b776a1b2dc779f5ee0641f8ade0125bc1298dd41a9a0c16d8bd57b42d222b1"
dependencies = [
"bytes",
"heck 0.4.1",
"itertools 0.11.0",
"heck 0.5.0",
"itertools 0.12.1",
"log",
"multimap",
"once_cell",
"petgraph",
"prettyplease",
"prost 0.12.3",
"prost 0.12.4",
"prost-types",
"regex",
"syn 2.0.58",
"syn 2.0.60",
"tempfile",
"which",
]
[[package]]
@ -2017,24 +2016,24 @@ dependencies = [
[[package]]
name = "prost-derive"
version = "0.12.3"
version = "0.12.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "efb6c9a1dd1def8e2124d17e83a20af56f1570d6c2d2bd9e266ccb768df3840e"
checksum = "19de2de2a00075bf566bee3bd4db014b11587e84184d3f7a791bc17f1a8e9e48"
dependencies = [
"anyhow",
"itertools 0.11.0",
"itertools 0.12.1",
"proc-macro2",
"quote",
"syn 2.0.58",
"syn 2.0.60",
]
[[package]]
name = "prost-types"
version = "0.12.3"
version = "0.12.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "193898f59edcf43c26227dcd4c8427f00d99d61e95dcde58dabd49fa291d470e"
checksum = "3235c33eb02c1f1e212abdbe34c78b264b038fb58ca612664343271e36e55ffe"
dependencies = [
"prost 0.12.3",
"prost 0.12.4",
]
[[package]]
@ -2055,9 +2054,9 @@ dependencies = [
[[package]]
name = "quote"
version = "1.0.35"
version = "1.0.36"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "291ec9ab5efd934aaf503a6466c5d5251535d108ee747472c3977cc5acc868ef"
checksum = "0fa76aaf39101c457836aec0ce2316dbdc3ab723cdda1c6bd4e6ad4208acaca7"
dependencies = [
"proc-macro2",
]
@ -2310,7 +2309,7 @@ dependencies = [
"quote",
"rust-embed-utils",
"shellexpand",
"syn 2.0.58",
"syn 2.0.60",
"walkdir",
]
@ -2406,9 +2405,9 @@ dependencies = [
[[package]]
name = "rustversion"
version = "1.0.14"
version = "1.0.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7ffc183a10b4478d04cbbbfc96d0873219d962dd5accaff2ffbd4ceb7df837f4"
checksum = "80af6f9131f277a45a3fba6ce8e2258037bb0477a67e610d3c1fe046ab31de47"
[[package]]
name = "ryu"
@ -2484,29 +2483,29 @@ dependencies = [
[[package]]
name = "serde"
version = "1.0.197"
version = "1.0.198"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3fb1c873e1b9b056a4dc4c0c198b24c3ffa059243875552b2bd0933b1aee4ce2"
checksum = "9846a40c979031340571da2545a4e5b7c4163bdae79b301d5f86d03979451fcc"
dependencies = [
"serde_derive",
]
[[package]]
name = "serde_derive"
version = "1.0.197"
version = "1.0.198"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7eb0b34b42edc17f6b7cac84a52a1c5f0e1bb2227e997ca9011ea3dd34e8610b"
checksum = "e88edab869b01783ba905e7d0153f9fc1a6505a96e4ad3018011eedb838566d9"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.58",
"syn 2.0.60",
]
[[package]]
name = "serde_json"
version = "1.0.115"
version = "1.0.116"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "12dc5c46daa8e9fdf4f5e71b6cf9a53f2487da0e86e55808e2d35539666497dd"
checksum = "3e17db7126d17feb94eb3fad46bf1a96b034e8aacbc2e775fe81505f8b0b2813"
dependencies = [
"itoa",
"ryu",
@ -2689,7 +2688,7 @@ dependencies = [
"proc-macro2",
"quote",
"rustversion",
"syn 2.0.58",
"syn 2.0.60",
]
[[package]]
@ -2711,9 +2710,9 @@ dependencies = [
[[package]]
name = "syn"
version = "2.0.58"
version = "2.0.60"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "44cfb93f38070beee36b3fef7d4f5a16f27751d94b187b666a5cc5e9b0d30687"
checksum = "909518bc7b1c9b779f1bbf07f2929d35af9f0f37e47c6e9ef7f9dddc1e1821f3"
dependencies = [
"proc-macro2",
"quote",
@ -2728,9 +2727,9 @@ checksum = "2047c6ded9c721764247e62cd3b03c09ffc529b2ba5b10ec482ae507a4a70160"
[[package]]
name = "sysinfo"
version = "0.30.8"
version = "0.30.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4b1a378e48fb3ce3a5cf04359c456c9c98ff689bcf1c1bc6e6a31f247686f275"
checksum = "26d7c217777061d5a2d652aea771fb9ba98b6dade657204b08c4b9604d11555b"
dependencies = [
"cfg-if",
"core-foundation-sys",
@ -2824,7 +2823,7 @@ version = "1.2.0"
dependencies = [
"futures",
"grpc-metadata",
"prost 0.12.3",
"prost 0.12.4",
"prost-build",
"rand",
"thiserror",
@ -2903,7 +2902,7 @@ checksum = "c61f3ba182994efc43764a46c018c347bc492c79f024e705f46567b418f6d4f7"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.58",
"syn 2.0.60",
]
[[package]]
@ -2918,9 +2917,9 @@ dependencies = [
[[package]]
name = "time"
version = "0.3.34"
version = "0.3.36"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c8248b6521bb14bc45b4067159b9b6ad792e2d6d754d6c41fb50e29fefe38749"
checksum = "5dfd88e563464686c916c7e46e623e520ddc6d79fa6641390f2e3fa86e83e885"
dependencies = [
"deranged",
"itoa",
@ -2941,9 +2940,9 @@ checksum = "ef927ca75afb808a4d64dd374f00a2adf8d0fcff8e7b184af886c3c87ec4a3f3"
[[package]]
name = "time-macros"
version = "0.2.17"
version = "0.2.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7ba3a3ef41e6672a2f0f001392bb5dcd3ff0a9992d618ca761a11c3121547774"
checksum = "3f252a68540fde3a3877aeea552b832b40ab9a69e318efd078774a01ddee1ccf"
dependencies = [
"num-conv",
"time-core",
@ -3035,7 +3034,7 @@ checksum = "5b8a1e28f2deaa14e508979454cb3a223b10b938b45af148bc0986de36f1923b"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.58",
"syn 2.0.60",
]
[[package]]
@ -3131,7 +3130,7 @@ dependencies = [
"hyper-timeout",
"percent-encoding",
"pin-project",
"prost 0.12.3",
"prost 0.12.4",
"tokio",
"tokio-stream",
"tower",
@ -3150,7 +3149,7 @@ dependencies = [
"proc-macro2",
"prost-build",
"quote",
"syn 2.0.58",
"syn 2.0.60",
]
[[package]]
@ -3223,7 +3222,7 @@ checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.58",
"syn 2.0.60",
]
[[package]]
@ -3464,7 +3463,7 @@ dependencies = [
"proc-macro2",
"quote",
"regex",
"syn 2.0.58",
"syn 2.0.60",
]
[[package]]
@ -3563,7 +3562,7 @@ dependencies = [
"once_cell",
"proc-macro2",
"quote",
"syn 2.0.58",
"syn 2.0.60",
"wasm-bindgen-shared",
]
@ -3597,7 +3596,7 @@ checksum = "e94f17b526d0a461a191c78ea52bbce64071ed5c04c9ffe424dcb38f74171bb7"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.58",
"syn 2.0.60",
"wasm-bindgen-backend",
"wasm-bindgen-shared",
]
@ -3637,18 +3636,6 @@ dependencies = [
"rustls-pki-types",
]
[[package]]
name = "which"
version = "4.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "87ba24419a2078cd2b0f2ede2691b6c66d8e47836da3b6db8265ebad47afbfc7"
dependencies = [
"either",
"home",
"once_cell",
"rustix",
]
[[package]]
name = "winapi"
version = "0.3.9"
@ -3687,7 +3674,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e48a53791691ab099e5e2ad123536d0fff50652600abaf43bbf952894110d0be"
dependencies = [
"windows-core",
"windows-targets 0.52.4",
"windows-targets 0.52.5",
]
[[package]]
@ -3696,7 +3683,7 @@ version = "0.52.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "33ab640c8d7e35bf8ba19b884ba838ceb4fba93a4e8c65a9059d08afcfc683d9"
dependencies = [
"windows-targets 0.52.4",
"windows-targets 0.52.5",
]
[[package]]
@ -3723,7 +3710,7 @@ version = "0.52.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d"
dependencies = [
"windows-targets 0.52.4",
"windows-targets 0.52.5",
]
[[package]]
@ -3758,17 +3745,18 @@ dependencies = [
[[package]]
name = "windows-targets"
version = "0.52.4"
version = "0.52.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7dd37b7e5ab9018759f893a1952c9420d060016fc19a472b4bb20d1bdd694d1b"
checksum = "6f0713a46559409d202e70e28227288446bf7841d3211583a4b53e3f6d96e7eb"
dependencies = [
"windows_aarch64_gnullvm 0.52.4",
"windows_aarch64_msvc 0.52.4",
"windows_i686_gnu 0.52.4",
"windows_i686_msvc 0.52.4",
"windows_x86_64_gnu 0.52.4",
"windows_x86_64_gnullvm 0.52.4",
"windows_x86_64_msvc 0.52.4",
"windows_aarch64_gnullvm 0.52.5",
"windows_aarch64_msvc 0.52.5",
"windows_i686_gnu 0.52.5",
"windows_i686_gnullvm",
"windows_i686_msvc 0.52.5",
"windows_x86_64_gnu 0.52.5",
"windows_x86_64_gnullvm 0.52.5",
"windows_x86_64_msvc 0.52.5",
]
[[package]]
@ -3785,9 +3773,9 @@ checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8"
[[package]]
name = "windows_aarch64_gnullvm"
version = "0.52.4"
version = "0.52.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bcf46cf4c365c6f2d1cc93ce535f2c8b244591df96ceee75d8e83deb70a9cac9"
checksum = "7088eed71e8b8dda258ecc8bac5fb1153c5cffaf2578fc8ff5d61e23578d3263"
[[package]]
name = "windows_aarch64_msvc"
@ -3803,9 +3791,9 @@ checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc"
[[package]]
name = "windows_aarch64_msvc"
version = "0.52.4"
version = "0.52.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "da9f259dd3bcf6990b55bffd094c4f7235817ba4ceebde8e6d11cd0c5633b675"
checksum = "9985fd1504e250c615ca5f281c3f7a6da76213ebd5ccc9561496568a2752afb6"
[[package]]
name = "windows_i686_gnu"
@ -3821,9 +3809,15 @@ checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e"
[[package]]
name = "windows_i686_gnu"
version = "0.52.4"
version = "0.52.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b474d8268f99e0995f25b9f095bc7434632601028cf86590aea5c8a5cb7801d3"
checksum = "88ba073cf16d5372720ec942a8ccbf61626074c6d4dd2e745299726ce8b89670"
[[package]]
name = "windows_i686_gnullvm"
version = "0.52.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "87f4261229030a858f36b459e748ae97545d6f1ec60e5e0d6a3d32e0dc232ee9"
[[package]]
name = "windows_i686_msvc"
@ -3839,9 +3833,9 @@ checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406"
[[package]]
name = "windows_i686_msvc"
version = "0.52.4"
version = "0.52.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1515e9a29e5bed743cb4415a9ecf5dfca648ce85ee42e15873c3cd8610ff8e02"
checksum = "db3c2bf3d13d5b658be73463284eaf12830ac9a26a90c717b7f771dfe97487bf"
[[package]]
name = "windows_x86_64_gnu"
@ -3857,9 +3851,9 @@ checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e"
[[package]]
name = "windows_x86_64_gnu"
version = "0.52.4"
version = "0.52.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5eee091590e89cc02ad514ffe3ead9eb6b660aedca2183455434b93546371a03"
checksum = "4e4246f76bdeff09eb48875a0fd3e2af6aada79d409d33011886d3e1581517d9"
[[package]]
name = "windows_x86_64_gnullvm"
@ -3875,9 +3869,9 @@ checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc"
[[package]]
name = "windows_x86_64_gnullvm"
version = "0.52.4"
version = "0.52.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "77ca79f2451b49fa9e2af39f0747fe999fcda4f5e241b2898624dca97a1f2177"
checksum = "852298e482cd67c356ddd9570386e2862b5673c85bd5f88df9ab6802b334c596"
[[package]]
name = "windows_x86_64_msvc"
@ -3893,9 +3887,9 @@ checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538"
[[package]]
name = "windows_x86_64_msvc"
version = "0.52.4"
version = "0.52.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "32b752e52a2da0ddfbdbcc6fceadfeede4c939ed16d13e648833a61dfb611ed8"
checksum = "bec47e5bfd1bff0eeaf6d8b485cc1074891a197ab4225d504cb7a1ab88b02bf0"
[[package]]
name = "winreg"
@ -3924,7 +3918,7 @@ checksum = "9ce1b18ccd8e73a9321186f97e46f9f04b778851177567b1975109d26a08d2a6"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.58",
"syn 2.0.60",
]
[[package]]


@ -67,6 +67,14 @@ Options:
- bitsandbytes-nf4: Bitsandbytes 4bit. Can be applied to any model; it will cut the memory requirement by 4x, but the model is known to run much slower than in native f16
- bitsandbytes-fp4: Bitsandbytes 4bit. nf4 should be preferred in most cases, but this variant may give better perplexity for your model
```
## SPECULATE
```shell
--speculate <SPECULATE>
The number of input_ids to speculate on. If using a medusa model, the heads will be picked up automatically. Otherwise, n-gram speculation is used, which is relatively free in terms of compute, but the speedup heavily depends on the task
[env: SPECULATE=]
```
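For intuition, here is a minimal, hypothetical sketch of the n-gram idea: reuse the tokens that followed an earlier occurrence of the current token, so the draft costs no extra forward pass. The real implementation works on tensors and may match differently; all names and values below are illustrative.
```python
# Minimal sketch of n-gram speculation (illustrative, not the TGI implementation):
# propose the tokens that followed the most recent earlier occurrence of the
# current last token in the sequence.
from typing import List


def ngram_speculate(input_ids: List[int], speculate: int) -> List[int]:
    if speculate == 0 or len(input_ids) < 2:
        return []
    last = input_ids[-1]
    # Walk backwards looking for a previous occurrence of the last token.
    for i in range(len(input_ids) - 2, -1, -1):
        if input_ids[i] == last:
            # Propose up to `speculate` tokens that followed it last time.
            return input_ids[i + 1 : i + 1 + speculate]
    return []


# Token 6509 already appeared followed by 338 and 263, so those are proposed.
print(ngram_speculate([1724, 338, 21784, 29257, 6509, 338, 263, 6509], speculate=2))
# [338, 263]
```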
## DTYPE
```shell


@ -0,0 +1,98 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 338,
"logprob": -10.0078125,
"text": "is"
},
{
"id": 21784,
"logprob": -15.515625,
"text": "Deep"
},
{
"id": 29257,
"logprob": -2.8847656,
"text": "Learning"
},
{
"id": 29973,
"logprob": -4.140625,
"text": "?"
}
],
"seed": 0,
"tokens": [
{
"id": 13,
"logprob": -1.1582031,
"special": false,
"text": "\n"
},
{
"id": 2772,
"logprob": -0.23083496,
"special": false,
"text": "De"
},
{
"id": 1022,
"logprob": 0.0,
"special": false,
"text": "ep"
},
{
"id": 6509,
"logprob": 0.0,
"special": false,
"text": " learning"
},
{
"id": 29892,
"logprob": -0.61816406,
"special": false,
"text": ","
},
{
"id": 607,
"logprob": -0.7089844,
"special": false,
"text": " which"
},
{
"id": 508,
"logprob": -1.7724609,
"special": false,
"text": " can"
},
{
"id": 367,
"logprob": 0.0,
"special": false,
"text": " be"
},
{
"id": 5545,
"logprob": 0.0,
"special": false,
"text": " considered"
},
{
"id": 408,
"logprob": -0.3869629,
"special": false,
"text": " as"
}
]
},
"generated_text": "What is Deep Learning?\nDeep learning, which can be considered as"
}


@ -0,0 +1,414 @@
[
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 1724,
"logprob": -10.734375,
"text": "What"
},
{
"id": 338,
"logprob": -1.5488281,
"text": "is"
},
{
"id": 21784,
"logprob": -9.2890625,
"text": "Deep"
},
{
"id": 29257,
"logprob": -1.2753906,
"text": "Learning"
},
{
"id": 29973,
"logprob": -0.48046875,
"text": "?"
}
],
"seed": null,
"tokens": [
{
"id": 13,
"logprob": -1.1845703,
"special": false,
"text": "\n"
},
{
"id": 2772,
"logprob": -0.5727539,
"special": false,
"text": "De"
},
{
"id": 1022,
"logprob": -0.00010967255,
"special": false,
"text": "ep"
},
{
"id": 6509,
"logprob": -0.1239624,
"special": false,
"text": " learning"
},
{
"id": 338,
"logprob": -0.04510498,
"special": false,
"text": " is"
},
{
"id": 263,
"logprob": -0.018295288,
"special": false,
"text": " a"
},
{
"id": 11306,
"logprob": -0.45922852,
"special": false,
"text": " subset"
},
{
"id": 310,
"logprob": -0.00020992756,
"special": false,
"text": " of"
},
{
"id": 4933,
"logprob": -0.0046539307,
"special": false,
"text": " machine"
},
{
"id": 6509,
"logprob": -0.00025844574,
"special": false,
"text": " learning"
}
]
},
"generated_text": "\nDeep learning is a subset of machine learning"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 1724,
"logprob": -10.734375,
"text": "What"
},
{
"id": 338,
"logprob": -1.5488281,
"text": "is"
},
{
"id": 21784,
"logprob": -9.2890625,
"text": "Deep"
},
{
"id": 29257,
"logprob": -1.2724609,
"text": "Learning"
},
{
"id": 29973,
"logprob": -0.47729492,
"text": "?"
}
],
"seed": null,
"tokens": [
{
"id": 13,
"logprob": -1.1826172,
"special": false,
"text": "\n"
},
{
"id": 2772,
"logprob": -0.56689453,
"special": false,
"text": "De"
},
{
"id": 1022,
"logprob": -0.000108003616,
"special": false,
"text": "ep"
},
{
"id": 6509,
"logprob": -0.1239624,
"special": false,
"text": " learning"
},
{
"id": 338,
"logprob": -0.044433594,
"special": false,
"text": " is"
},
{
"id": 263,
"logprob": -0.018295288,
"special": false,
"text": " a"
},
{
"id": 11306,
"logprob": -0.45922852,
"special": false,
"text": " subset"
},
{
"id": 310,
"logprob": -0.0002104044,
"special": false,
"text": " of"
},
{
"id": 4933,
"logprob": -0.004711151,
"special": false,
"text": " machine"
},
{
"id": 6509,
"logprob": -0.00025892258,
"special": false,
"text": " learning"
}
]
},
"generated_text": "\nDeep learning is a subset of machine learning"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 1724,
"logprob": -10.734375,
"text": "What"
},
{
"id": 338,
"logprob": -1.5488281,
"text": "is"
},
{
"id": 21784,
"logprob": -9.2890625,
"text": "Deep"
},
{
"id": 29257,
"logprob": -1.2724609,
"text": "Learning"
},
{
"id": 29973,
"logprob": -0.47729492,
"text": "?"
}
],
"seed": null,
"tokens": [
{
"id": 13,
"logprob": -1.1826172,
"special": false,
"text": "\n"
},
{
"id": 2772,
"logprob": -0.56689453,
"special": false,
"text": "De"
},
{
"id": 1022,
"logprob": -0.000108003616,
"special": false,
"text": "ep"
},
{
"id": 6509,
"logprob": -0.1239624,
"special": false,
"text": " learning"
},
{
"id": 338,
"logprob": -0.044433594,
"special": false,
"text": " is"
},
{
"id": 263,
"logprob": -0.018295288,
"special": false,
"text": " a"
},
{
"id": 11306,
"logprob": -0.45922852,
"special": false,
"text": " subset"
},
{
"id": 310,
"logprob": -0.0002104044,
"special": false,
"text": " of"
},
{
"id": 4933,
"logprob": -0.004711151,
"special": false,
"text": " machine"
},
{
"id": 6509,
"logprob": -0.00025892258,
"special": false,
"text": " learning"
}
]
},
"generated_text": "\nDeep learning is a subset of machine learning"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 1724,
"logprob": -10.734375,
"text": "What"
},
{
"id": 338,
"logprob": -1.5488281,
"text": "is"
},
{
"id": 21784,
"logprob": -9.2890625,
"text": "Deep"
},
{
"id": 29257,
"logprob": -1.2724609,
"text": "Learning"
},
{
"id": 29973,
"logprob": -0.47729492,
"text": "?"
}
],
"seed": null,
"tokens": [
{
"id": 13,
"logprob": -1.1826172,
"special": false,
"text": "\n"
},
{
"id": 2772,
"logprob": -0.56689453,
"special": false,
"text": "De"
},
{
"id": 1022,
"logprob": -0.000108003616,
"special": false,
"text": "ep"
},
{
"id": 6509,
"logprob": -0.1239624,
"special": false,
"text": " learning"
},
{
"id": 338,
"logprob": -0.044433594,
"special": false,
"text": " is"
},
{
"id": 263,
"logprob": -0.018295288,
"special": false,
"text": " a"
},
{
"id": 11306,
"logprob": -0.45922852,
"special": false,
"text": " subset"
},
{
"id": 310,
"logprob": -0.0002104044,
"special": false,
"text": " of"
},
{
"id": 4933,
"logprob": -0.004711151,
"special": false,
"text": " machine"
},
{
"id": 6509,
"logprob": -0.00025892258,
"special": false,
"text": " learning"
}
]
},
"generated_text": "\nDeep learning is a subset of machine learning"
}
]


@ -0,0 +1,103 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 1724,
"logprob": -10.734375,
"text": "What"
},
{
"id": 338,
"logprob": -1.5488281,
"text": "is"
},
{
"id": 21784,
"logprob": -9.2890625,
"text": "Deep"
},
{
"id": 29257,
"logprob": -1.2753906,
"text": "Learning"
},
{
"id": 29973,
"logprob": -0.48046875,
"text": "?"
}
],
"seed": null,
"tokens": [
{
"id": 13,
"logprob": -1.1845703,
"special": false,
"text": "\n"
},
{
"id": 2772,
"logprob": -0.5727539,
"special": false,
"text": "De"
},
{
"id": 1022,
"logprob": -0.000108122826,
"special": false,
"text": "ep"
},
{
"id": 6509,
"logprob": -0.1239624,
"special": false,
"text": " learning"
},
{
"id": 338,
"logprob": -0.044433594,
"special": false,
"text": " is"
},
{
"id": 263,
"logprob": -0.01852417,
"special": false,
"text": " a"
},
{
"id": 11306,
"logprob": -0.45922852,
"special": false,
"text": " subset"
},
{
"id": 310,
"logprob": -0.0002104044,
"special": false,
"text": " of"
},
{
"id": 4933,
"logprob": -0.004787445,
"special": false,
"text": " machine"
},
{
"id": 6509,
"logprob": -0.00026226044,
"special": false,
"text": " learning"
}
]
},
"generated_text": "\nDeep learning is a subset of machine learning"
}


@ -0,0 +1,59 @@
import pytest
@pytest.fixture(scope="module")
def flash_medusa_handle(launcher):
with launcher("FasterDecoding/medusa-vicuna-7b-v1.3", num_shard=2) as handle:
yield handle
@pytest.fixture(scope="module")
async def flash_medusa(flash_medusa_handle):
await flash_medusa_handle.health(300)
return flash_medusa_handle.client
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_medusa_simple(flash_medusa, response_snapshot):
response = await flash_medusa.generate(
"What is Deep Learning?", max_new_tokens=10, decoder_input_details=True
)
assert response.details.generated_tokens == 10
assert response == response_snapshot
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_medusa_all_params(flash_medusa, response_snapshot):
response = await flash_medusa.generate(
"What is Deep Learning?",
max_new_tokens=10,
repetition_penalty=1.2,
return_full_text=True,
stop_sequences=["test"],
temperature=0.5,
top_p=0.9,
top_k=10,
truncate=5,
typical_p=0.9,
watermark=True,
decoder_input_details=True,
seed=0,
)
assert response.details.generated_tokens == 10
assert response == response_snapshot
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_medusa_load(flash_medusa, generate_load, response_snapshot):
responses = await generate_load(flash_medusa, "What is Deep Learning?", max_new_tokens=10, n=4)
assert len(responses) == 4
assert all([r.generated_text == responses[0].generated_text for r in responses]), f"{[r.generated_text for r in responses]}"
assert responses[0].generated_text == '\nDeep learning is a subset of machine learning'
assert responses == response_snapshot


@ -21,6 +21,7 @@ async def test_flash_mistral(flash_mistral, response_snapshot):
)
assert response.details.generated_tokens == 10
assert response.generated_text == ": Let n = 10 - 1"
assert response == response_snapshot
@ -55,6 +56,7 @@ async def test_flash_mistral_load(flash_mistral, generate_load, response_snapsho
)
assert len(responses) == 4
assert all([r.generated_text == responses[0].generated_text for r in responses])
assert all([r.generated_text == responses[0].generated_text for r in responses]), f"{[r.generated_text for r in responses]}"
assert responses[0].generated_text == ": Let n = 10 - 1"
assert responses == response_snapshot


@ -157,6 +157,13 @@ struct Args {
#[clap(long, env, value_enum)]
quantize: Option<Quantization>,
/// The number of input_ids to speculate on
/// If using a medusa model, the heads will be picked up automatically
/// Otherwise, it will use n-gram speculation, which is relatively free
/// in terms of compute, but the speedup heavily depends on the task.
#[clap(long, env)]
speculate: Option<usize>,
/// The dtype to be forced upon the model. This option cannot be used with `--quantize`.
#[clap(long, env, value_enum)]
dtype: Option<Dtype>,
@ -377,6 +384,7 @@ fn shard_manager(
model_id: String,
revision: Option<String>,
quantize: Option<Quantization>,
speculate: Option<usize>,
dtype: Option<Dtype>,
max_total_tokens: usize,
trust_remote_code: bool,
@ -435,6 +443,11 @@ fn shard_manager(
shard_args.push(quantize.to_string())
}
if let Some(speculate) = speculate {
shard_args.push("--speculate".to_string());
shard_args.push(speculate.to_string())
}
if let Some(dtype) = dtype {
shard_args.push("--dtype".to_string());
shard_args.push(dtype.to_string())
@ -890,6 +903,7 @@ fn spawn_shards(
let shutdown_sender = shutdown_sender.clone();
let otlp_endpoint = args.otlp_endpoint.clone();
let quantize = args.quantize;
let speculate = args.speculate;
let dtype = args.dtype;
let max_total_tokens = args.max_total_tokens;
let trust_remote_code = args.trust_remote_code;
@ -905,6 +919,7 @@ fn spawn_shards(
model_id,
revision,
quantize,
speculate,
dtype,
max_total_tokens,
trust_remote_code,


@ -7,7 +7,9 @@ const seed = 0;
const host = __ENV.HOST || '127.0.0.1:8000';
const timePerToken = new Trend('time_per_token', true);
const throughput = new Counter('tokens_per_s');
const tokens = new Counter('tokens');
const new_tokens = new Counter('new_tokens');
const input_tokens = new Counter('input_tokens');
randomSeed(seed);
// const shareGPT = JSON.parse(open("ShareGPT_V3_unfiltered_cleaned_split.json"))
@ -19,7 +21,7 @@ export function get_options(reference_latency_ms){
thresholds: {
http_req_failed: ['rate==0'],
time_per_token: [{
threshold: `p(50)<${3 * reference_latency_ms}`,
threshold: `p(50)<${5 * reference_latency_ms}`,
abortOnFail: true,
delayAbortEval: '10s'
}],
@ -28,7 +30,7 @@ export function get_options(reference_latency_ms){
load_test: {
executor: 'constant-arrival-rate',
duration: '60s',
preAllocatedVUs: 100,
preAllocatedVUs: 10,
rate: 10,
timeUnit: '1s',
},
@ -48,17 +50,22 @@ export function run(host, generate_payload, max_new_tokens) {
return;
}
check(res, {
'Post status is 200': (r) => res.status === 200,
});
const n_tokens = max_new_tokens;
const timings = res.timings.duration;
const duration = res.timings.duration;
if (res.status === 200) {
const latency_ms_per_token = timings / n_tokens;
const body = res.json();
const n_tokens = body.details.tokens.length;
const latency_ms_per_token = duration / n_tokens;
timePerToken.add(latency_ms_per_token);
const latency_in_s = latency_ms_per_token / 1000;
const individual_throughput = 1 / latency_in_s;
throughput.add(individual_throughput);
const _input_tokens = body.details.prefill.length;
tokens.add(n_tokens + _input_tokens);
input_tokens.add(_input_tokens);
new_tokens.add(n_tokens);
}
}
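Because speculation changes how much work a single request does, the script now reads the real token counts from the response body instead of assuming `max_new_tokens`. A rough Python equivalent of the measurement, using the response shape visible in the snapshots above (host and payload values are placeholders):
```python
# Time one /generate call and derive per-token latency from the number of
# tokens actually returned in `details.tokens` (shape taken from the snapshots
# in this PR; host and payload are illustrative).
import time

import requests

host = "127.0.0.1:8000"
payload = {
    "inputs": "What is Deep Learning?",
    "parameters": {"max_new_tokens": 50, "decoder_input_details": True, "details": True},
}

start = time.monotonic()
res = requests.post(f"http://{host}/generate", json=payload, timeout=60)
duration_ms = (time.monotonic() - start) * 1000

if res.status_code == 200:
    details = res.json()["details"]
    new_tokens = len(details["tokens"])     # tokens actually generated
    input_tokens = len(details["prefill"])  # prompt tokens
    print(f"new={new_tokens} input={input_tokens} ms/token={duration_ms / new_tokens:.2f}")
```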


@ -1,13 +1,13 @@
import { get_options, run } from "./common.js";
const reference_latency_ms = 30;
const reference_latency_ms = 70;
const host = __ENV.HOST || '127.0.0.1:8000';
const max_new_tokens = 50;
function generate_payload(gpt){
const input = gpt["conversations"][0]["value"];
return {"inputs": input, "parameters": {"max_new_tokens": max_new_tokens, "temperature" : 0.5}}
return {"inputs": input, "parameters": {"max_new_tokens": max_new_tokens, "decoder_input_details": true}}
}
export const options = get_options(reference_latency_ms);


@ -1,6 +1,6 @@
syntax = "proto3";
package generate.v1;
package generate.v2;
service TextGenerationService {
/// Model Info
@ -32,6 +32,7 @@ message InfoResponse {
string dtype = 2;
string device_type = 3;
optional uint32 window_size = 4;
uint32 speculate = 5;
}
/// Empty request
@ -135,43 +136,27 @@ message GeneratedText {
optional uint64 seed = 4;
}
message PrefillTokens {
/// Prefill Token IDs
message Tokens {
/// Token IDs
repeated uint32 ids = 1;
/// Prefill Logprobs
/// Logprobs
repeated float logprobs = 2;
/// Prefill tokens
/// tokens
repeated string texts = 3;
}
message TopTokens {
/// Top Token IDs
repeated uint32 ids = 1;
/// Top Logprobs
repeated float logprobs = 2;
/// Top Token Texts
repeated string texts = 3;
/// If the tokens are special
repeated bool is_special = 6;
/// special
repeated bool is_special = 4;
}
message Generation {
/// Request ID
uint64 request_id = 1;
/// Prefill tokens (optional)
PrefillTokens prefill_tokens = 2;
/// Token ID
uint32 token_id = 3;
/// Logprob
float token_logprob = 4;
/// Text
string token_text = 5;
/// Is it a special token
bool token_is_special = 6;
Tokens prefill_tokens = 2;
Tokens tokens = 3;
/// Complete generated text
optional GeneratedText generated_text = 7;
optional GeneratedText generated_text = 4;
/// Top tokens
TopTokens top_tokens = 8;
repeated Tokens top_tokens = 5;
}
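In v2 the same `Tokens` message is reused for prefill tokens, generated tokens, and top tokens, and a single `Generation` can now carry several tokens when speculative ids are accepted in one step. A hypothetical Python mirror of the shapes (the real types are generated from this proto):
```python
# Hypothetical mirror of the v2 messages, to make the new shape explicit.
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class Tokens:
    ids: List[int]
    logprobs: List[float]
    texts: List[str]
    is_special: List[bool]


@dataclass
class Generation:
    request_id: int
    prefill_tokens: Optional[Tokens]
    tokens: Tokens                 # 1 + however many speculative ids were accepted
    generated_text: Optional[str] = None  # stands in for the GeneratedText message
    top_tokens: List[Tokens] = field(default_factory=list)  # one entry per emitted token


# A decode step where one speculative token was accepted emits two tokens at once:
step = Generation(
    request_id=0,
    prefill_tokens=None,
    tokens=Tokens(ids=[6509, 338], logprobs=[-0.12, -0.04],
                  texts=[" learning", " is"], is_special=[False, False]),
)
```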
message FilterBatchRequest {


@ -1,8 +1,8 @@
/// Copyright (C) 2024 Habana Labs, Ltd. an Intel Company.
/// Single shard Client
use crate::pb::generate::v1::text_generation_service_client::TextGenerationServiceClient;
use crate::pb::generate::v1::*;
use crate::pb::generate::v2::text_generation_service_client::TextGenerationServiceClient;
use crate::pb::generate::v2::*;
use crate::Result;
use std::env;
use rand::{distributions::Uniform, Rng};


@ -6,11 +6,11 @@ mod pb;
mod sharded_client;
pub use client::Client;
pub use pb::generate::v1::HealthResponse;
pub use pb::generate::v1::InfoResponse as ShardInfo;
pub use pb::generate::v1::{
pub use pb::generate::v2::HealthResponse;
pub use pb::generate::v2::InfoResponse as ShardInfo;
pub use pb::generate::v2::{
Batch, CachedBatch, FinishReason, GeneratedText, Generation, NextTokenChooserParameters,
PrefillTokens, Request, StoppingCriteriaParameters,
Request, StoppingCriteriaParameters, Tokens,
};
pub use sharded_client::ShardedClient;
use thiserror::Error;


@ -11,7 +11,7 @@ use std::sync::{
Arc,
};
use text_generation_client::{
Batch, CachedBatch, ClientError, GeneratedText, Generation, PrefillTokens, ShardedClient,
Batch, CachedBatch, ClientError, GeneratedText, Generation, ShardedClient, Tokens,
};
use thiserror::Error;
use tokio::sync::mpsc::error::SendError;
@ -54,6 +54,7 @@ impl Infer {
max_input_length: u32,
max_total_tokens: u32,
window_size: Option<u32>,
speculate: u32,
generation_health: Arc<AtomicBool>,
) -> Self {
// Infer shared state
@ -62,7 +63,8 @@ impl Infer {
max_input_length,
max_total_tokens,
16,
window_size
window_size,
speculate
);
let shared = Arc::new(Shared {
batching_task: Notify::new(),
@ -533,50 +535,63 @@ fn send_responses(
}
// Create last Token
let token = Token {
id: generation.token_id,
text: generation.token_text,
logprob: generation.token_logprob,
special: generation.token_is_special,
};
// generation.top_tokens
let mut top_tokens = Vec::new();
if let Some(top_tokens_) = generation.top_tokens {
top_tokens.extend(
let tokens_ = generation.tokens.expect("Non empty tokens in generation");
let n = tokens_.ids.len();
metrics::histogram!("tgi_request_skipped_tokens", (n - 1) as f64);
let mut iterator = tokens_
.ids
.into_iter()
.zip(tokens_.logprobs.into_iter())
.zip(tokens_.texts.into_iter())
.zip(tokens_.is_special.into_iter())
.enumerate()
.peekable();
while let Some((i, (((id, logprob), text), special))) = iterator.next() {
let token = Token {
id,
text,
logprob,
special,
};
let top_tokens = if let Some(top_tokens_) = generation.top_tokens.get(i) {
top_tokens_
.ids
.into_iter()
.zip(top_tokens_.logprobs.into_iter())
.zip(top_tokens_.texts.into_iter())
.zip(top_tokens_.is_special.into_iter())
.map(|(((id, logprob), text), special)| Token {
.iter()
.zip(top_tokens_.logprobs.iter())
.zip(top_tokens_.texts.iter())
.zip(top_tokens_.is_special.iter())
.map(|(((&id, &logprob), text), &special)| Token {
id,
text,
text: text.to_string(),
logprob,
special,
}),
)
})
.collect()
} else {
vec![]
};
match (&generation.generated_text, iterator.peek()) {
(Some(generated_text), None) => {
// Generation has ended
stopped = true;
// Send message
entry.response_tx.send(Ok(InferStreamResponse::End {
token,
top_tokens,
generated_text: generated_text.clone(),
queued: entry.queue_time,
start: entry.batch_time.unwrap(),
}))?;
}
_ => {
// Send message
entry
.response_tx
.send(Ok(InferStreamResponse::Intermediate { token, top_tokens }))?;
}
}
}
if let Some(generated_text) = generation.generated_text {
// Generation has ended
stopped = true;
// Send message
entry.response_tx.send(Ok(InferStreamResponse::End {
token,
top_tokens,
generated_text,
queued: entry.queue_time,
start: entry.batch_time.unwrap(),
}))?;
} else {
// Send message
entry
.response_tx
.send(Ok(InferStreamResponse::Intermediate { token, top_tokens }))?;
}
Ok(stopped)
}
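The router now walks every token returned for the step instead of exactly one; the last one, when the shard also sent `generated_text`, becomes the `End` message. A rough Python rendering of that loop (names follow the Rust code, types are simplified):
```python
# Simplified rendering of send_responses: split one Generation carrying N
# accepted tokens into N-1 Intermediate events and, if the request finished,
# one final End event that also carries the generated text.
def split_generation(tokens, top_tokens, generated_text):
    events = []
    n = len(tokens.ids)
    for i, (id_, logprob, text, special) in enumerate(
        zip(tokens.ids, tokens.logprobs, tokens.texts, tokens.is_special)
    ):
        token = {"id": id_, "logprob": logprob, "text": text, "special": special}
        tops = top_tokens[i] if i < len(top_tokens) else []
        if generated_text is not None and i == n - 1:
            events.append(("End", token, tops, generated_text))
        else:
            events.append(("Intermediate", token, tops))
    return events
```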
@ -601,7 +616,7 @@ fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>) {
#[derive(Debug)]
pub(crate) enum InferStreamResponse {
// Optional first message
Prefill(PrefillTokens),
Prefill(Tokens),
// Intermediate messages
Intermediate {
token: Token,


@ -44,7 +44,8 @@ impl Queue {
max_input_length: u32,
max_total_tokens: u32,
block_size: u32,
window_size: Option<u32>
window_size: Option<u32>,
speculate: u32,
) -> Self {
// Create channel
let (queue_sender, queue_receiver) = mpsc::unbounded_channel();
@ -56,6 +57,7 @@ impl Queue {
max_total_tokens,
block_size,
window_size,
speculate,
queue_receiver,
));
@ -106,6 +108,7 @@ async fn queue_task(
max_total_tokens: u32,
block_size: u32,
window_size: Option<u32>,
speculate: u32,
mut receiver: mpsc::UnboundedReceiver<QueueCommand>,
) {
let mut state = State::new(
@ -113,7 +116,8 @@ async fn queue_task(
max_input_length,
max_total_tokens,
block_size,
window_size
window_size,
speculate
);
while let Some(cmd) = receiver.recv().await {
@ -256,6 +260,9 @@ struct State {
/// Sliding window
window_size: Option<u32>,
/// Speculation amount
speculate: u32,
}
impl State {
@ -265,6 +272,7 @@ impl State {
max_total_tokens: u32,
block_size: u32,
window_size: Option<u32>,
speculate: u32,
) -> Self {
let default_threshold: u64 = 120;
let threshold: u64 = match env::var("QUEUE_THRESHOLD_MS") {
@ -281,6 +289,7 @@ impl State {
max_total_tokens,
block_size,
window_size,
speculate,
}
}
@ -365,7 +374,7 @@ impl State {
}
if prefill_tokens > prefill_token_budget
|| (prefill_tokens + decode_tokens) > token_budget
|| (prefill_tokens + decode_tokens + self.speculate) > token_budget
{
// Entry is over budget
// Add it back to the front
@ -457,13 +466,13 @@ mod tests {
fn default_queue() -> Queue {
Queue::new(
true, 1, 2, 1, None
true, 1, 2, 1, None, 0
)
}
fn default_state() -> State {
State::new(
true, 1, 2, 1, None
true, 1, 2, 1, None, 0
)
}
@ -667,6 +676,25 @@ mod tests {
assert_eq!(batch.size, 2);
}
#[tokio::test]
async fn test_queue_next_batch_token_speculate() {
let queue = Queue::new(true, 1, 2, 1, None, 2);
let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry();
queue.append(entry1);
queue.append(entry2);
// Budget of 1 is not enough
assert!(queue.next_batch(None, 1, 1).await.is_none());
let (entries, batch, _) = queue.next_batch(None, 6, 6).await.unwrap();
assert_eq!(entries.len(), 2);
assert!(entries.contains_key(&0));
assert!(entries.contains_key(&1));
assert_eq!(batch.id, 0);
assert_eq!(batch.size, 2);
}
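The admission check now reserves room for the speculative ids a decode step may append. A hedged sketch of that check; the token counts are illustrative, not the queue's exact accounting:
```python
# Sketch of the new budget check: besides prompt and decode tokens, the batch
# must leave `speculate` extra slots in the token budget.
def fits_budget(prefill_tokens: int, decode_tokens: int, speculate: int,
                prefill_token_budget: int, token_budget: int) -> bool:
    if prefill_tokens > prefill_token_budget:
        return False
    return (prefill_tokens + decode_tokens + speculate) <= token_budget


# A batch that exactly filled the budget without speculation no longer fits
# once two speculative slots are reserved.
assert fits_budget(4, 4, 0, 8, 8)
assert not fits_budget(4, 4, 2, 8, 8)
```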
#[tokio::test]
async fn test_queue_next_batch_dropped_receiver() {
let queue = default_queue();


@ -600,6 +600,7 @@ pub async fn run(
max_input_length as u32,
max_total_tokens as u32,
shard_info.window_size,
shard_info.speculate,
generation_health,
);


@ -1,22 +1,25 @@
build-vllm-cuda: REPOSITORY=https://github.com/vllm-project/vllm.git
build-vllm-cuda: VLLM_COMMIT=f8a1e39fae05ca610be8d5a78be9d40f5274e5fc
build-vllm-cuda: BRANCH=main
build-vllm-cuda: build-vllm
build-vllm-rocm: REPOSITORY=https://github.com/fxmarty/vllm-public.git
build-vllm-rocm: VLLM_COMMIT=ad9b7c4095ef54419a0533d254f2ad84bd2dfcae
build-vllm-rocm: BRANCH=rotary-no-positions-split-cos-sin
build-vllm-rocm: build-vllm
vllm:
vllm-cuda:
# Clone vllm
pip install -U ninja packaging --no-cache-dir
git clone --single-branch --branch $(BRANCH) $(REPOSITORY) vllm
git clone https://github.com/vllm-project/vllm.git vllm
build-vllm: vllm
cd vllm && git fetch && git checkout $(VLLM_COMMIT)
build-vllm-cuda: vllm-cuda
cd vllm && git fetch && git checkout f8a1e39fae05ca610be8d5a78be9d40f5274e5fc
cd vllm && python setup.py build
install-vllm: build-vllm
install-vllm-cuda: build-vllm-cuda
pip uninstall vllm -y || true
cd vllm && python setup.py install
vllm-rocm:
# Clone vllm
pip install -U ninja packaging --no-cache-dir
git clone https://github.com/fxmarty/vllm-public.git vllm
build-vllm-rocm: vllm-rocm
cd vllm && git fetch && git checkout ad9b7c4095ef54419a0533d254f2ad84bd2dfcae
cd vllm && python setup.py build
install-vllm-rocm: build-vllm-rocm
pip uninstall vllm -y || true
cd vllm && python setup.py install


@ -135,8 +135,8 @@ def test_causal_lm_generate_token(default_bloom, default_bloom_batch):
)
assert all([generation.generated_text is None for generation in generations])
assert all([len(generation.prefill_tokens) == 1 for generation in generations])
assert all([generation.token_id.item() == 10264 for generation in generations])
assert all([generation.token_text == "Test" for generation in generations])
assert all([token_id.item() == 10264 for generation in generations for token_id in generation.tokens.token_ids])
assert all([token_text == "Test" for generation in generations for token_text in generation.tokens.texts])
assert generations[0].request_id == 0


@ -141,8 +141,8 @@ def test_causal_lm_generate_token(default_causal_lm, default_causal_lm_batch):
)
assert all([generation.generated_text is None for generation in generations])
assert all([len(generation.prefill_tokens) == 1 for generation in generations])
assert all([generation.token_id.item() == 13 for generation in generations])
assert all([generation.token_text == "." for generation in generations])
assert all([token_id.item() == 13 for generation in generations for token_id in generation.tokens.token_ids])
assert all([token_text == "." for generation in generations for token_text in generation.tokens.texts])
assert generations[0].request_id == 0


@ -155,8 +155,8 @@ def test_seq2seq_lm_generate_token(default_seq2seq_lm, default_seq2seq_lm_batch)
)
assert all([generation.generated_text is None for generation in generations])
assert all([len(generation.prefill_tokens) == 1 for generation in generations])
assert all([generation.token_id.item() == 259 for generation in generations])
assert all([generation.token_text == " " for generation in generations])
assert all([token_id.item() == 259 for generation in generations for token_id in generation.tokens.token_ids])
assert all([token_text == " " for generation in generations for token_text in generation.tokens.texts])
assert generations[0].request_id == 0


@ -10,6 +10,7 @@ from pathlib import Path
from loguru import logger
from typing import Optional
from enum import Enum
from huggingface_hub import hf_hub_download
app = typer.Typer()
@ -31,6 +32,7 @@ def serve(
revision: Optional[str] = None,
sharded: bool = False,
quantize: Optional[Quantization] = None,
speculate: Optional[int] = None,
dtype: Optional[Dtype] = None,
trust_remote_code: bool = False,
uds_path: Path = "/tmp/text-generation-server",
@ -39,9 +41,15 @@ def serve(
otlp_endpoint: Optional[str] = None,
):
if sharded:
assert os.getenv("WORLD_SIZE", None) is not None, "WORLD_SIZE must be set when sharded is True"
assert os.getenv("MASTER_ADDR", None) is not None, "MASTER_ADDR must be set when sharded is True"
assert os.getenv("MASTER_PORT", None) is not None, "MASTER_PORT must be set when sharded is True"
assert (
os.getenv("WORLD_SIZE", None) is not None
), "WORLD_SIZE must be set when sharded is True"
assert (
os.getenv("MASTER_ADDR", None) is not None
), "MASTER_ADDR must be set when sharded is True"
assert (
os.getenv("MASTER_PORT", None) is not None
), "MASTER_PORT must be set when sharded is True"
# Remove default handler
logger.remove()
@ -75,7 +83,11 @@ def serve(
logger.info("CLI SHARDED = {}".format(num_shard))
import subprocess
cmd = f"deepspeed --num_nodes 1 --num_gpus {num_shard} --no_local_rank {tgi_file} --model_id {model_id} --revision {revision} --sharded {sharded} --dtype {dtype} --uds_path {uds_path}"
cmd = f"deepspeed --num_nodes 1 --num_gpus {num_shard} --no_local_rank {tgi_file}"
cmd += f" --model_id {model_id} --revision {revision} --sharded {sharded}"
cmd += f" --dtype {dtype} --trust_remote_code {trust_remote_code} --uds_path {uds_path}"
if speculate is not None:
cmd += f"--speculate {speculate}"
logger.info("CLI server start deepspeed ={} ".format(cmd))
sys.stdout.flush()
sys.stderr.flush()
@ -119,7 +131,9 @@ def serve(
logger.error(f"{cmd} exited with status = {proc.returncode}")
return proc.returncode
else:
server.serve(model_id, revision, dtype, uds_path, sharded)
server.serve(
model_id, revision, sharded, speculate, dtype, trust_remote_code, uds_path
)
@app.command()
@ -153,7 +167,7 @@ def download_weights(
logger.info("Files are already present on the host. " "Skipping download.")
return
# Local files not found
except (utils.LocalEntryNotFoundError, FileNotFoundError):
except (utils.LocalEntryNotFoundError, FileNotFoundError, utils.EntryNotFoundError):
pass
is_local_model = (Path(model_id).exists() and Path(model_id).is_dir()) or os.getenv(
@ -161,6 +175,42 @@ def download_weights(
) is not None
if not is_local_model:
try:
adapter_config_filename = hf_hub_download(
model_id, revision=revision, filename="adapter_config.json"
)
utils.download_and_unload_peft(
model_id, revision, trust_remote_code=trust_remote_code
)
is_local_model = True
utils.weight_files(model_id, revision, extension)
return
except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
pass
try:
import json
medusa_head = hf_hub_download(model_id, revision=revision, filename="medusa_lm_head.pt")
if auto_convert:
medusa_sf = Path(medusa_head[:-len(".pt")] + ".safetensors")
if not medusa_sf.exists():
utils.convert_files([Path(medusa_head)], [medusa_sf], [])
medusa_config = hf_hub_download(model_id, revision=revision, filename="config.json")
with open(medusa_config, "r") as f:
config = json.load(f)
model_id = config["base_model_name_or_path"]
revision = "main"
try:
utils.weight_files(model_id, revision, extension)
logger.info(f"Files for parent {model_id} are already present on the host. " "Skipping download.")
return
# Local files not found
except (utils.LocalEntryNotFoundError, FileNotFoundError, utils.EntryNotFoundError):
pass
except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
pass
# Try to download weights from the hub
try:
filenames = utils.weight_hub_files(model_id, revision, extension)


@ -1,10 +1,14 @@
import torch
from loguru import logger
from transformers.configuration_utils import PretrainedConfig
from transformers.models.auto import modeling_auto
from transformers import AutoConfig
from typing import Optional
# Needed to properly setup habana_frameworks
import text_generation_server.habana_quantization_env as hq_env
from text_generation_server.utils.speculate import get_speculate, set_speculate
from text_generation_server.models.model import Model
from text_generation_server.models.causal_lm import CausalLM
from text_generation_server.models.bloom import BLOOM
@ -18,10 +22,46 @@ torch.set_grad_enabled(False)
def get_model(
model_id: str,
revision: Optional[str],
dtype: Optional[torch.dtype] = None,
speculate: Optional[int],
dtype: Optional[torch.dtype],
trust_remote_code: bool,
) -> Model:
config = AutoConfig.from_pretrained(model_id, revision=revision)
model_type = config.model_type
if speculate is not None:
set_speculate(speculate)
else:
set_speculate(0)
config_dict, _ = PretrainedConfig.get_config_dict(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
use_medusa = None
if "medusa_num_heads" in config_dict:
use_medusa = model_id
medusa_config = config_dict
model_id = config_dict["base_model_name_or_path"]
revision = "main"
speculate_medusa = config_dict["medusa_num_heads"]
if speculate is not None:
if speculate > speculate_medusa:
raise RuntimeError("Speculate is set to `{speculate}` but this medusa models only has `{speculate_medusa}` heads, please make them match")
else:
set_speculate(speculate)
else:
set_speculate(speculate_medusa)
config_dict, _ = PretrainedConfig.get_config_dict(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
method = "medusa"
else:
method = "n-gram"
speculate = get_speculate()
if speculate > 0:
logger.info(f"Using speculation {method} with {speculate} input ids.")
model_type = config_dict["model_type"]
if model_type == "gpt_bigcode":
return SantaCoder(model_id, revision, dtype)
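`set_speculate`/`get_speculate` come from the new `text_generation_server.utils.speculate` module, which does not appear in this extract; judging by its usage it is just a process-wide setting, roughly:
```python
# Presumed contents of text_generation_server/utils/speculate.py (not shown in
# this diff): a single module-level value shared by the model implementations.
SPECULATE = None


def set_speculate(speculate: int):
    global SPECULATE
    SPECULATE = speculate


def get_speculate() -> int:
    return SPECULATE
```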


@ -35,10 +35,9 @@ from text_generation_server.utils.tokens import batch_top_tokens
from text_generation_server.models import Model
from text_generation_server.models.types import (
Batch,
PrefillTokens,
Tokens,
Generation,
GeneratedText,
TopTokens,
)
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import (
@ -48,6 +47,7 @@ from text_generation_server.utils import (
is_tokenizer_transparent,
)
from text_generation_server.utils.debug import dbg_trace
from text_generation_server.utils.speculate import get_speculate
tracer = trace.get_tracer(__name__)
@ -647,6 +647,8 @@ class CausalLM(Model):
kwargs["attn_softmax_bf16"] = True
kwargs["trim_logits"] = True
self.speculate = get_speculate()
super(CausalLM, self).__init__(
model=model,
tokenizer=tokenizer,
@ -842,12 +844,12 @@ class CausalLM(Model):
# Select next token
input_length = batch.input_length
if logits.shape[-2] > 1:
next_token_ids, next_token_logprobs, logprobs = batch.next_token_chooser(
batch.input_ids, logits[:, input_length - 1: input_length, :].squeeze(-2)
next_token_ids, next_token_logprobs, logprobs, _, _ = batch.next_token_chooser(
batch.input_ids, logits[:, input_length - 1: input_length, :].squeeze(-2), self.speculate
)
else:
next_token_ids, next_token_logprobs, logprobs = batch.next_token_chooser(
batch.input_ids, logits.squeeze(-2)
next_token_ids, next_token_logprobs, logprobs, _, _ = batch.next_token_chooser(
batch.input_ids, logits.squeeze(-2), self.speculate
)
batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
batch.top_n_tokens,
@ -1017,7 +1019,9 @@ class CausalLM(Model):
clean_up_tokenization_spaces=False,
skip_special_tokens=False,
)
prefill_tokens = PrefillTokens(prefill_token_ids, prefill_logprobs, prefill_texts)
prefill_tokens = Tokens(
prefill_token_ids, prefill_logprobs, prefill_texts, is_special=[]
)
else:
prefill_tokens = None
@ -1027,8 +1031,10 @@ class CausalLM(Model):
clean_up_tokenization_spaces=False,
skip_special_tokens=False,
)
special_toptokens = [token_id in self.all_special_ids for token_id in top_token_ids]
top_tokens = TopTokens(
special_toptokens = [
token_id in self.all_special_ids for token_id in top_token_ids
]
top_tokens = Tokens(
top_token_ids,
top_token_logprobs,
toptoken_texts,
@ -1040,10 +1046,12 @@ class CausalLM(Model):
generation = Generation(
request.id,
prefill_tokens,
next_token_id,
next_token_logprob,
next_token_text,
next_token_id in self.all_special_ids,
Tokens(
[next_token_id],
[next_token_logprob],
[next_token_text],
[next_token_id in self.all_special_ids],
),
generated_text,
top_tokens,
)


@ -11,13 +11,13 @@ from opentelemetry import trace
from transformers import PreTrainedTokenizerBase
from typing import Optional, Tuple, List, Type, Union, Dict
from text_generation_server.models import Model
from text_generation_server.models import Model
from text_generation_server.utils.speculate import get_speculate
from text_generation_server.models.types import (
Batch,
PrefillTokens,
Tokens,
Generation,
GeneratedText,
TopTokens,
)
from text_generation_server.models.cache_manager import (
get_cache_manager,
@ -41,6 +41,7 @@ class FlashCausalLMBatch(Batch):
# Decoder values
input_ids: torch.Tensor
position_ids: torch.Tensor
speculative_ids: torch.Tensor
# Flash Attention values
@ -120,6 +121,7 @@ class FlashCausalLMBatch(Batch):
)["input_ids"]
position_ids = []
speculative_ids = []
cu_seqlen_prefill = [0]
needed_blocks_slots = []
start_slots = []
@ -163,6 +165,8 @@ class FlashCausalLMBatch(Batch):
input_length = len(tokenized_input)
input_lengths.append(input_length)
prefix_offsets.append(input_length - 5)
read_offsets.append(input_length)
@ -186,7 +190,8 @@ class FlashCausalLMBatch(Batch):
# Paged attention
# Remove one as the first token does not have a past
total_tokens = input_length + max_new_tokens - 1
speculative_length = get_speculate()
total_tokens = input_length + max_new_tokens - 1 + speculative_length
needed_blocks = math.ceil(total_tokens / BLOCK_SIZE)
blocks += needed_blocks
needed_blocks_slots.append((needed_blocks, total_tokens))
@ -224,7 +229,7 @@ class FlashCausalLMBatch(Batch):
cumulative_max_length += total_tokens
max_seqlen = max(max_seqlen, input_length)
max_blocks = max(max_blocks, needed_blocks)
max_length = max(max_length, input_length + max_new_tokens)
max_length = max(max_length, input_length + max_new_tokens + speculative_length)
next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
next_token_chooser_parameters, dtype, device
@ -255,7 +260,6 @@ class FlashCausalLMBatch(Batch):
cu_seqlen_prefill = torch.tensor(
cu_seqlen_prefill, device=device, dtype=torch.int32
)
position_ids = position_ids.to(device)
slot_indices = slot_indices.to(device)
input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
@ -309,6 +313,7 @@ class FlashCausalLMBatch(Batch):
top_n_tokens_tensor=top_n_tokens_tensor,
blocks=blocks,
max_blocks=max_blocks,
speculative_ids=None,
)
@tracer.start_as_current_span("filter")
@ -419,6 +424,7 @@ class FlashCausalLMBatch(Batch):
slots = self.slots[slot_filtering_indices]
next_token_chooser = self.next_token_chooser.filter(indices)
top_n_tokens_tensor = self.top_n_tokens_tensor[indices]
speculative_ids = self.speculative_ids[indices] if self.speculative_ids is not None else None
start_slots = torch.tensor(start_slots, dtype=torch.int64)
@ -454,6 +460,7 @@ class FlashCausalLMBatch(Batch):
top_n_tokens_tensor=top_n_tokens_tensor,
blocks=blocks,
max_blocks=max_blocks,
speculative_ids=speculative_ids,
)
@classmethod
@ -473,6 +480,7 @@ class FlashCausalLMBatch(Batch):
total_batch_size += len(b)
total_slots += len(b.slots)
blocks += b.blocks
speculative_length = b.speculative_ids.shape[1] if b.speculative_ids is not None else 0
max_blocks = max(max_blocks, b.max_blocks)
max_seqlen = max(max_seqlen, b.max_seqlen)
max_length = max(
@ -480,6 +488,7 @@ class FlashCausalLMBatch(Batch):
max(
input_length
+ stopping_criteria.max_new_tokens
+ speculative_length
- stopping_criteria.current_tokens
for input_length, stopping_criteria in zip(
b.input_lengths, b.stopping_criterias
@ -577,6 +586,8 @@ class FlashCausalLMBatch(Batch):
device=batches[0].next_token_chooser.device,
)
speculative_ids = torch.cat([b.speculative_ids for b in batches], dim=0) if batches[0].speculative_ids is not None else None
# Needed to avoid dropping blocks when the batches will go out of scope
for b in batches:
b.block_tables = None
@ -611,6 +622,7 @@ class FlashCausalLMBatch(Batch):
top_n_tokens_tensor=top_n_tokens_tensor,
blocks=blocks,
max_blocks=max_blocks,
speculative_ids=speculative_ids
)
def __del__(self):
@ -714,16 +726,55 @@ class FlashCausalLM(Model):
def forward(self, batch: FlashCausalLMBatch) -> Tuple[torch.Tensor, torch.Tensor]:
# Model Forward
if batch.speculative_ids is not None:
input_ids = batch.input_ids
position_ids = batch.position_ids
cu_seqlen_prefill = batch.cu_seqlen_prefill
kv_cache = get_cache_manager().kv_cache
block_tables = batch.block_tables_tensor
slots = batch.slots[batch.slot_indices]
input_lengths = batch.input_lengths_tensor
max_s = batch.max_seqlen
lm_head_indices = batch.prefill_head_indices
speculative_ids = batch.speculative_ids
B, speculative_length = speculative_ids.shape
new_length = speculative_length + 1
new_input_ids = torch.cat([input_ids.unsqueeze(-1), speculative_ids], dim=1).reshape(-1)
arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0)
arange_int = arange.to(dtype=torch.int32)
new_position_ids = (position_ids.unsqueeze(-1).expand(B, new_length) + arange).view(-1)
slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
input_lengths = (input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
# Copy the block tables for all members
block_tables = block_tables.unsqueeze(1).expand(B, new_length, -1).reshape(B * new_length, -1).contiguous()
max_s = max_s + speculative_length
input_ids = new_input_ids
position_ids = new_position_ids
else:
input_ids = batch.input_ids
position_ids = batch.position_ids
cu_seqlen_prefill = batch.cu_seqlen_prefill
kv_cache = get_cache_manager().kv_cache
block_tables = batch.block_tables_tensor
slots = batch.slots[batch.slot_indices]
input_lengths = batch.input_lengths_tensor
max_s = batch.max_seqlen
lm_head_indices = batch.prefill_head_indices
return self.model.forward(
input_ids=batch.input_ids,
position_ids=batch.position_ids,
cu_seqlen_prefill=batch.cu_seqlen_prefill,
kv_cache=get_cache_manager().kv_cache,
block_tables=batch.block_tables_tensor,
slots=batch.slots[batch.slot_indices],
input_lengths=batch.input_lengths_tensor,
max_s=batch.max_seqlen,
lm_head_indices=batch.prefill_head_indices,
input_ids=input_ids,
position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill,
kv_cache=kv_cache,
block_tables=block_tables,
slots=slots,
input_lengths=input_lengths,
max_s=max_s,
lm_head_indices=lm_head_indices,
)
@tracer.start_as_current_span("generate_token")
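For reference, a minimal sketch (toy values, not from the commit) of the tensor expansion done in the speculative branch of forward() above: each sequence's current token is interleaved with its speculative ids, and position ids and cache slots are offset per candidate.

import torch

B, speculative_length = 2, 2
new_length = speculative_length + 1  # current token + speculative candidates

input_ids = torch.tensor([11, 21])                   # one token per sequence at decode time
speculative_ids = torch.tensor([[12, 13], [22, 23]])
position_ids = torch.tensor([5, 9])
slots = torch.tensor([40, 80], dtype=torch.int32)

# Interleave each sequence's token with its speculative candidates.
new_input_ids = torch.cat([input_ids.unsqueeze(-1), speculative_ids], dim=1).reshape(-1)

arange = torch.arange(new_length).unsqueeze(0)
# Each candidate gets its own position and KV-cache slot, offset by 0..new_length-1.
new_position_ids = (position_ids.unsqueeze(-1).expand(B, new_length) + arange).view(-1)
new_slots = (slots.unsqueeze(-1).expand(B, new_length) + arange.to(torch.int32)).view(-1)

print(new_input_ids.tolist())     # [11, 12, 13, 21, 22, 23]
print(new_position_ids.tolist())  # [5, 6, 7, 9, 10, 11]
print(new_slots.tolist())         # [40, 41, 42, 80, 81, 82]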
@ -752,21 +803,32 @@ class FlashCausalLM(Model):
del batch
raise e
if isinstance(out, tuple):
out, speculative_logits = out
else:
speculative_logits = None
if prefill:
next_token_logits = (
out[batch.prefill_next_token_indices] if prefill_logprobs else out
)
if speculative_logits is not None:
speculative_logits = (
speculative_logits[batch.prefill_next_token_indices] if prefill_logprobs else speculative_logits
)
else:
next_token_logits = out
next_input_ids, next_token_logprobs, logprobs = batch.next_token_chooser(
batch.all_input_ids_tensor[:, : batch.max_seqlen], next_token_logits
next_input_ids, next_token_logprobs, logprobs, accepted_ids, speculative_ids = batch.next_token_chooser(
batch.all_input_ids_tensor[:, : batch.max_seqlen], next_token_logits, get_speculate(), batch.speculative_ids, speculative_logits
)
batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs
)
speculative_length = 0 if speculative_ids is None else speculative_ids.shape[1]
if prefill:
if len(batch) > 1 and prefill_logprobs:
# We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs
@ -792,6 +854,7 @@ class FlashCausalLM(Model):
iterator = zip(
batch.input_lengths,
batch.all_input_ids,
accepted_ids
)
# We do two for loops as the first one can run completely asynchronously from the GPU while for the second
@ -799,9 +862,11 @@ class FlashCausalLM(Model):
# It is faster if we delay this sync for the maximum amount of time
# For each member of the batch
index = 0
for i, (
input_length,
all_input_ids,
n_accepted_ids
) in enumerate(iterator):
# Indexing metadata
start_index = cumulative_length
@ -830,15 +895,18 @@ class FlashCausalLM(Model):
start_index + 1 : start_index + out_length
]
batch.all_input_ids_tensor[i, input_length] = next_input_ids[i]
for j in range(n_accepted_ids):
batch.all_input_ids_tensor[i, input_length + j] = next_input_ids[index]
index += 1
cumulative_length += input_length
# Set values in batch
batch.input_ids = next_input_ids
batch.position_ids = next_position_ids + 1
batch.input_lengths_tensor += 1
batch.slot_indices += 1
batch.input_ids = next_input_ids[accepted_ids.cumsum(dim=-1) - 1]
batch.speculative_ids = speculative_ids
batch.position_ids = next_position_ids + accepted_ids
batch.input_lengths_tensor += accepted_ids
batch.slot_indices += accepted_ids
if prefill and prefill_logprobs:
# Get prefill logprobs
@ -851,7 +919,7 @@ class FlashCausalLM(Model):
# GPU <-> CPU sync
next_token_logprobs = next_token_logprobs.tolist()
next_token_ids = batch.input_ids.tolist()
next_token_ids = next_input_ids.tolist()
# Zipped iterator
iterator = zip(
@ -864,13 +932,13 @@ class FlashCausalLM(Model):
batch.next_token_chooser.do_sample,
batch.next_token_chooser.seeds,
batch.top_n_tokens,
next_token_ids,
next_token_logprobs,
accepted_ids,
batch_top_token_ids,
batch_top_token_logprobs,
)
# For each member of the batch
index = 0
for i, (
request,
input_length,
@ -881,29 +949,43 @@ class FlashCausalLM(Model):
do_sample,
seed,
top_n_tokens,
next_token_id,
next_token_logprob,
n_accepted_ids,
top_token_ids,
top_token_logprobs,
) in enumerate(iterator):
# Append next token to all tokens
all_input_ids.append(next_token_id)
next_token_texts = []
left = 0
before = stopping_criteria.current_tokens
# Generated token
next_token_text, prefix_offset, read_offset = self.decode_token(
all_input_ids,
prefix_offset,
read_offset,
)
current_stopped = False
for j in range(index, index + n_accepted_ids):
# Generated token
next_token_id = next_token_ids[j]
all_input_ids.append(next_token_id)
next_token_text, prefix_offset, read_offset = self.decode_token(
all_input_ids,
prefix_offset,
read_offset,
)
next_token_texts.append(next_token_text)
# Evaluate stopping criteria
stop, reason = stopping_criteria(
next_token_id,
next_token_text,
)
stop, reason = stopping_criteria(
next_token_id,
next_token_text,
)
if not stop:
stopped = False
if stop:
left = index + n_accepted_ids - j - 1
current_stopped = True
break
else:
current_stopped = False
stopped = stopped and current_stopped
_next_token_ids = next_token_ids[index : index + n_accepted_ids - left]
_next_token_logprobs = next_token_logprobs[index : index + n_accepted_ids - left]
index += n_accepted_ids
# Shard generations
# All generations will be appended in the rust sharded client
@ -943,8 +1025,9 @@ class FlashCausalLM(Model):
clean_up_tokenization_spaces=False,
skip_special_tokens=False,
)
prefill_tokens = PrefillTokens(
prefill_token_ids, request_prefill_logprobs, prefill_texts
prefill_tokens = Tokens(
prefill_token_ids, request_prefill_logprobs, prefill_texts, is_special=[]
)
else:
prefill_tokens = None
@ -958,7 +1041,7 @@ class FlashCausalLM(Model):
special_toptokens = [
token_id in self.all_special_ids for token_id in top_token_ids
]
top_tokens = TopTokens(
top_tokens = Tokens(
top_token_ids,
top_token_logprobs,
toptoken_texts,
@ -970,10 +1053,12 @@ class FlashCausalLM(Model):
generation = Generation(
request.id,
prefill_tokens,
next_token_id,
next_token_logprob,
next_token_text,
next_token_id in self.all_special_ids,
Tokens(
_next_token_ids,
_next_token_logprobs,
next_token_texts,
[nid in self.all_special_ids for nid in _next_token_ids],
),
generated_text,
top_tokens,
)
@ -981,7 +1066,9 @@ class FlashCausalLM(Model):
generations.append(generation)
# Update values
batch.input_lengths[i] = input_length + 1
batch.input_lengths[i] = input_length + n_accepted_ids.item()
if batch.input_lengths[i] > batch.max_seqlen:
batch.max_seqlen = batch.input_lengths[i]
batch.prefix_offsets[i] = prefix_offset
batch.read_offsets[i] = read_offset
batch.all_input_ids[i] = all_input_ids
@ -994,6 +1081,5 @@ class FlashCausalLM(Model):
batch.prefill_cu_outlens = None
batch.prefill_head_indices = None
batch.prefill_next_token_indices = None
batch.max_seqlen = batch.max_seqlen + 1
return generations, batch
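A small illustration of the accepted-token bookkeeping above, with invented counts: next_input_ids is flat over all accepted tokens, cumsum(accepted_ids) - 1 picks each request's last accepted token as its next input, and lengths, positions and slot indices advance by accepted_ids instead of a fixed 1.

import torch

# Two requests: the first accepted 3 tokens this step, the second only 1.
accepted_ids = torch.tensor([3, 1])
next_input_ids = torch.tensor([101, 102, 103, 201])  # flat: 3 tokens for req 0, 1 for req 1

# The last accepted token per request becomes that request's next input id.
last_per_request = next_input_ids[accepted_ids.cumsum(dim=-1) - 1]
print(last_per_request.tolist())  # [103, 201]

# Lengths (and likewise positions and KV-cache slot indices) advance by the
# number of accepted tokens rather than by one.
input_lengths = torch.tensor([7, 12])
print((input_lengths + accepted_ids).tolist())  # [10, 13]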

View File

@ -28,6 +28,7 @@ class FlashLlama(FlashCausalLM):
quantize: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
use_medusa: Optional[str] = None,
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
@ -66,6 +67,18 @@ class FlashLlama(FlashCausalLM):
weights._set_gptq_params(model_id)
model = FlashLlamaForCausalLM(config, weights)
if use_medusa:
from text_generation_server.utils.medusa import MedusaModel
from huggingface_hub import hf_hub_download
import json
medusa_config = hf_hub_download(use_medusa, revision=revision, filename="config.json")
with open(medusa_config, "r") as f:
config = json.load(f)
medusa_head = hf_hub_download(use_medusa, revision=revision, filename="medusa_lm_head.pt")
medusa_sf = medusa_head[:-len(".pt")] + ".safetensors"
weights = Weights([medusa_sf], device, dtype, process_group=self.process_group)
lm_head = model.lm_head
model.lm_head = MedusaModel(config, weights, lm_head)
torch.distributed.barrier(group=self.process_group)
super(FlashLlama, self).__init__(

View File

@ -21,6 +21,7 @@ from text_generation_server.models.custom_modeling.flash_mistral_modeling import
FlashMistralForCausalLM,
MistralConfig,
)
from text_generation_server.utils.speculate import get_speculate
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
@ -132,7 +133,8 @@ class FlashMistralBatch(FlashCausalLMBatch):
# Paged attention
# Remove one as the first token does not have a past
total_tokens = input_length + max_new_tokens - 1
speculative_length = get_speculate()
total_tokens = input_length + max_new_tokens - 1 + speculative_length
# Needed blocks can not go over SLIDING_WINDOW_BLOCKS
needed_blocks = min(
@ -183,7 +185,7 @@ class FlashMistralBatch(FlashCausalLMBatch):
cumulative_max_length += total_tokens
max_seqlen = max(max_seqlen, input_length)
max_blocks = max(max_blocks, needed_blocks)
max_length = max(max_length, input_length + max_new_tokens)
max_length = max(max_length, input_length + max_new_tokens + speculative_length)
next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
next_token_chooser_parameters, dtype, device
@ -272,6 +274,7 @@ class FlashMistralBatch(FlashCausalLMBatch):
blocks=blocks,
max_blocks=max_blocks,
prefill_cache_indices=prefill_cache_indices,
speculative_ids=None
)
@ -340,17 +343,55 @@ class FlashMistral(FlashCausalLM):
def forward(self, batch: FlashMistralBatch) -> Tuple[torch.Tensor, torch.Tensor]:
# Model Forward
if batch.speculative_ids is not None:
input_ids = batch.input_ids
position_ids = batch.position_ids
cu_seqlen_prefill = batch.cu_seqlen_prefill
kv_cache = get_cache_manager().kv_cache
block_tables = batch.block_tables_tensor
slots = batch.slots[batch.slot_indices]
input_lengths = batch.input_lengths_tensor
max_s = batch.max_seqlen
lm_head_indices = batch.prefill_head_indices
speculative_ids = batch.speculative_ids
B, speculative_length = speculative_ids.shape
new_length = speculative_length + 1
new_input_ids = torch.cat([input_ids.unsqueeze(-1), speculative_ids], dim=1).reshape(-1)
arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0)
arange_int = arange.to(dtype=torch.int32)
new_position_ids = (position_ids.unsqueeze(-1).expand(B, new_length) + arange).view(-1)
slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
input_lengths = (input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
# Copy the block tables for all members
block_tables = block_tables.unsqueeze(1).expand(B, new_length, -1).reshape(B * new_length, -1).contiguous()
max_s = max_s + speculative_length
input_ids = new_input_ids
position_ids = new_position_ids
else:
input_ids = batch.input_ids
position_ids = batch.position_ids
cu_seqlen_prefill = batch.cu_seqlen_prefill
kv_cache = get_cache_manager().kv_cache
block_tables = batch.block_tables_tensor
slots = batch.slots[batch.slot_indices]
input_lengths = batch.input_lengths_tensor
max_s = batch.max_seqlen
lm_head_indices = batch.prefill_head_indices
logits = self.model.forward(
input_ids=batch.input_ids,
position_ids=batch.position_ids,
cu_seqlen_prefill=batch.cu_seqlen_prefill,
kv_cache=get_cache_manager().kv_cache,
block_tables=batch.block_tables_tensor,
slots=batch.slots[batch.slot_indices],
input_lengths=batch.input_lengths_tensor,
max_s=batch.max_seqlen,
input_ids=input_ids,
position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill,
kv_cache=kv_cache,
block_tables=block_tables,
slots=slots,
input_lengths=input_lengths,
max_s=max_s,
prefill_cache_indices=batch.prefill_cache_indices,
lm_head_indices=batch.prefill_head_indices,
lm_head_indices=lm_head_indices,
)
if batch.prefill_cache_indices is not None:
batch.prefill_cache_indices = None

View File

@ -20,7 +20,7 @@ from typing import Optional, Tuple, List, Type, Dict
from text_generation_server.models import Model
from text_generation_server.models.types import (
Batch,
PrefillTokens,
Tokens,
Generation,
GeneratedText,
)
@ -791,8 +791,8 @@ class IdeficsCausalLM(Model):
clean_up_tokenization_spaces=False,
skip_special_tokens=False,
)
prefill_tokens = PrefillTokens(
prefill_token_ids, prefill_logprobs, prefill_texts
prefill_tokens = Tokens(
prefill_token_ids, prefill_logprobs, prefill_texts, is_special=[]
)
else:
prefill_tokens = None
@ -802,10 +802,12 @@ class IdeficsCausalLM(Model):
generation = Generation(
request.id,
prefill_tokens,
next_token_id_squeezed,
next_token_logprob,
next_token_text,
next_token_id_squeezed.item() in self.all_special_ids,
Tokens(
[next_token_id_squeezed],
[next_token_logprob],
[next_token_text],
[next_token_id_squeezed.item() in self.all_special_ids],
),
generated_text,
top_tokens,
)

View File

@ -5,7 +5,8 @@ from abc import ABC, abstractmethod
from typing import List, Optional, Tuple, Type, TypeVar
from transformers import PreTrainedTokenizerBase
from text_generation_server.models.types import Batch, GeneratedText
from text_generation_server.models.types import Batch, Generation
from text_generation_server.utils.speculate import get_speculate
from text_generation_server.pb.generate_pb2 import InfoResponse
B = TypeVar("B", bound=Batch)
@ -22,6 +23,7 @@ class Model(ABC):
rank: int = 0,
world_size: int = 1,
kwargs: dict = {},
speculate: Optional[int] = None,
):
self.model = model
self.tokenizer = tokenizer
@ -32,7 +34,14 @@ class Model(ABC):
self.rank = rank
self.world_size = world_size
self.kwargs = kwargs
self.has_position_ids = inspect.signature(model.forward).parameters.get("position_ids", None) is not None
if speculate is None:
speculate = get_speculate()
self.speculate = speculate
self.has_position_ids = (
inspect.signature(model.forward).parameters.get("position_ids", None)
is not None
)
self.check_initialized()
@ -42,6 +51,7 @@ class Model(ABC):
requires_padding=self.requires_padding,
dtype=str(self.dtype),
device_type=self.device.type,
speculate=self.speculate
)
@property
@ -50,7 +60,7 @@ class Model(ABC):
raise NotImplementedError
@abstractmethod
def generate_token(self, batch: B) -> Tuple[List[GeneratedText], Optional[B]]:
def generate_token(self, batch: B) -> Tuple[List[Generation], Optional[B]]:
raise NotImplementedError
def warmup(self, batch: B, max_total_tokens: int):

View File

@ -11,8 +11,7 @@ from text_generation_server.models.types import (
GeneratedText,
Batch,
Generation,
PrefillTokens,
TopTokens,
Tokens,
)
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
@ -733,10 +732,11 @@ class Seq2SeqLM(Model):
# Prefill
if stopping_criteria.current_tokens == 1 and request.prefill_logprobs:
prefill_tokens = PrefillTokens(
prefill_tokens = Tokens(
[self.tokenizer.bos_token_id],
[float("nan")],
[self.tokenizer.bos_token],
[False]
)
else:
prefill_tokens = None
@ -750,7 +750,7 @@ class Seq2SeqLM(Model):
special_toptokens = [
token_id in self.all_special_ids for token_id in top_token_ids
]
top_tokens = TopTokens(
top_tokens = Tokens(
top_token_ids,
top_token_logprobs,
toptoken_texts,
@ -762,10 +762,12 @@ class Seq2SeqLM(Model):
generation = Generation(
request.id,
prefill_tokens,
next_token_id_squeezed,
next_token_logprob,
next_token_text,
next_token_id_squeezed.item() in self.all_special_ids,
Tokens(
[next_token_id_squeezed],
[next_token_logprob],
[next_token_text],
[next_token_id_squeezed.item() in self.all_special_ids],
),
generated_text,
top_tokens,
)

View File

@ -58,33 +58,15 @@ class GeneratedText:
@dataclass
class PrefillTokens:
token_ids: List[int]
logprobs: List[float]
texts: List[str]
def to_pb(self) -> generate_pb2.PrefillTokens:
return generate_pb2.PrefillTokens(
ids=self.token_ids, logprobs=self.logprobs, texts=self.texts
)
def __len__(self):
return len(self.token_ids)
@dataclass
class TopTokens:
class Tokens:
token_ids: List[int]
logprobs: List[float]
texts: List[str]
is_special: List[bool]
def to_pb(self) -> generate_pb2.TopTokens:
return generate_pb2.TopTokens(
ids=self.token_ids,
logprobs=self.logprobs,
texts=self.texts,
is_special=self.is_special,
def to_pb(self) -> generate_pb2.Tokens:
return generate_pb2.Tokens(
ids=self.token_ids, logprobs=self.logprobs, texts=self.texts, is_special=self.is_special
)
def __len__(self):
@ -94,14 +76,11 @@ class TopTokens:
@dataclass
class Generation:
request_id: int
prefill_tokens: Optional[PrefillTokens]
token_id: int
token_logprob: float
token_text: str
token_is_special: bool
prefill_tokens: Optional[Tokens]
tokens: Tokens
generated_text: Optional[GeneratedText]
# Optional for now, since it's not yet supported for every model.
top_tokens: Optional[TopTokens]
top_tokens: Optional[List[Tokens]]
def to_pb(self) -> generate_pb2.Generation:
return generate_pb2.Generation(
@ -109,10 +88,7 @@ class Generation:
prefill_tokens=self.prefill_tokens.to_pb()
if self.prefill_tokens is not None
else None,
token_id=self.token_id,
token_logprob=self.token_logprob,
token_text=self.token_text,
token_is_special=self.token_is_special,
tokens=self.tokens.to_pb(),
generated_text=self.generated_text.to_pb()
if self.generated_text is not None
else None,
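A short sketch of the unified Tokens container introduced above, replacing the former PrefillTokens/TopTokens pair; the ids, logprobs and texts below are placeholders, and the import assumes the text_generation_server package is importable.

from text_generation_server.models.types import Tokens

# Prefill, per-step and top-k tokens now all share the same container.
prefill = Tokens(
    token_ids=[1, 3923],
    logprobs=[float("nan"), -1.2],
    texts=["<s>", "What"],
    is_special=[],            # the commit passes an empty list for prefill tokens
)
step = Tokens(
    token_ids=[374, 279],     # several ids per step once speculation accepts extra tokens
    logprobs=[-0.4, -0.9],
    texts=[" is", " the"],
    is_special=[False, False],
)
print(len(step))              # counts the tokens -> 2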

View File

@ -107,9 +107,11 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
def serve(
model_id: str,
revision: Optional[str],
dtype: Optional[str],
uds_path: Path,
sharded: bool,
speculate: Optional[int],
dtype: Optional[str],
trust_remote_code: bool,
uds_path: Path,
):
# Remove default handler
logger.remove()
@ -126,8 +128,10 @@ def serve(
async def serve_inner(
model_id: str,
revision: Optional[str],
dtype: Optional[str] = None,
sharded: bool = False,
speculate: Optional[int] = None,
dtype: Optional[str] = None,
trust_remote_code: bool = False,
):
unix_socket_template = "unix://{}-{}"
logger.info("Server:server_inner: sharded ={}".format(sharded))
@ -151,7 +155,9 @@ def serve(
if revision == "None":
revision = None
try:
model = get_model(model_id, revision=revision, dtype=data_type)
model = get_model(
model_id, revision, speculate, dtype=data_type, trust_remote_code=trust_remote_code
)
except Exception:
logger.exception("Error when initializing model")
raise
@ -181,13 +187,7 @@ def serve(
except KeyboardInterrupt:
logger.info("Signal received. Shutting down")
await server.stop(0)
finally:
if hasattr(model,'finish_quantization_measurements'):
model.finish_quantization_measurements()
logger.info(
"Starting Server : model_id= {}, revision = {} dtype = {} sharded = {} ".format(
model_id, revision, dtype, sharded
)
asyncio.run(
serve_inner(model_id, revision, sharded, speculate, dtype, trust_remote_code)
)
asyncio.run(serve_inner(model_id, revision, dtype, sharded))

View File

@ -9,12 +9,18 @@ import argparse
def main(args):
logger.info("TGIService: starting tgi service .... ")
logger.info(
"TGIService: --model_id {}, --revision {}, --sharded {}, --dtype {}, --uds_path {} ".format(
args.model_id, args.revision, args.sharded, args.dtype, args.uds_path
"TGIService: --model_id {}, --revision {}, --sharded {}, --speculate {}, --dtype {}, --trust_remote_code {}, --uds_path {} ".format(
args.model_id, args.revision, args.sharded, args.speculate, args.dtype, args.trust_remote_code, args.uds_path
)
)
server.serve(
model_id=args.model_id, revision=args.revision, dtype=args.dtype, uds_path=args.uds_path, sharded=args.sharded
model_id=args.model_id,
revision=args.revision,
sharded=args.sharded,
speculate=args.speculate,
dtype=args.dtype,
trust_remote_code=args.trust_remote_code,
uds_path=args.uds_path,
)
@ -23,7 +29,9 @@ if __name__ == "__main__":
parser.add_argument("--model_id", type=str)
parser.add_argument("--revision", type=str)
parser.add_argument("--sharded", type=bool)
parser.add_argument("--speculate", type=int, default=None)
parser.add_argument("--dtype", type=str)
parser.add_argument("--trust_remote_code", type=bool)
parser.add_argument("--uds_path", type=Path)
args = parser.parse_args()
main(args)

View File

@ -0,0 +1,51 @@
import torch
from dataclasses import dataclass
from text_generation_server.utils.layers import TensorParallelHead, FastLinear
@dataclass
class Output:
logits: torch.FloatTensor = None
speculative_logits: torch.FloatTensor = None
class ResBlock(torch.nn.Module):
def __init__(self, config, prefix, weights):
super().__init__()
self.linear = FastLinear.load(config, prefix=f"{prefix}.linear", weights=weights, bias=True)
self.act = torch.nn.SiLU()
def forward(self, x):
return x + self.act(self.linear(x))
class MedusaModel(torch.nn.Module):
def __init__(
self,
config,
weights,
lm_head
):
super().__init__()
self.heads = torch.nn.ModuleList(
[MedusaHead(config, prefix=f"{i}", weights=weights) for i in range(config["medusa_num_heads"])]
)
self.lm_head = lm_head
def forward(self, x):
logits = self.lm_head(x)
speculative_logits = torch.stack([head(x) for head in self.heads], dim=1)
return logits, speculative_logits
class MedusaHead(torch.nn.Module):
def __init__(self, config, prefix, weights):
super().__init__()
self.blocks = torch.nn.ModuleList([ResBlock(config, prefix=f"{prefix}.{i}", weights=weights) for i in range(config["medusa_num_layers"])])
n = len(self.blocks)
self.out = FastLinear.load(config, prefix=f"{prefix}.{n}", weights=weights, bias=False)
def forward(self, x):
for block in self.blocks:
x = block(x)
x = self.out(x)
return x
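A rough, shape-only sketch of what MedusaModel.forward returns, using plain Linear layers as stand-ins for the ResBlock heads and invented sizes: the wrapped lm_head still yields ordinary logits, while each extra head contributes one speculative position.

import torch

# Made-up dimensions: hidden=8, vocab=32, 3 Medusa heads, batch of 2.
hidden_size, vocab_size, n_heads, batch = 8, 32, 3, 2
lm_head = torch.nn.Linear(hidden_size, vocab_size, bias=False)
medusa_heads = [torch.nn.Linear(hidden_size, vocab_size, bias=False) for _ in range(n_heads)]

x = torch.randn(batch, hidden_size)
logits = lm_head(x)                                                     # (batch, vocab)
speculative_logits = torch.stack([h(x) for h in medusa_heads], dim=1)   # (batch, n_heads, vocab)
print(logits.shape, speculative_logits.shape)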

View File

@ -0,0 +1,12 @@
SPECULATE = None
def get_speculate() -> int:
global SPECULATE
return SPECULATE
def set_speculate(speculate: int):
global SPECULATE
SPECULATE = speculate
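Usage is a simple process-wide setting: the server calls set_speculate once at startup and everything else reads it back via get_speculate, which returns None until it has been set. The value below is only an example.

from text_generation_server.utils.speculate import get_speculate, set_speculate

set_speculate(2)        # e.g. derived from a --speculate flag or the Medusa config
assert get_speculate() == 2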

View File

@ -142,6 +142,22 @@ class StoppingCriteria:
)
def create_n_gram_speculation(input_ids: torch.Tensor, next_ids: torch.Tensor, accepted_ids: torch.Tensor, speculate: int, verbose: bool):
# Very trivial approach: find the first match of the seed token in the sequence.
# This is much less refined than a real n-gram model but seems to work
# reasonably well in grounded mode and is far faster, with much lower
# worst-case complexity, since everything happens on device.
B = accepted_ids.shape[0]
device = input_ids.device
seeds = next_ids[accepted_ids.cumsum(dim=-1) - 1]
indices = (input_ids == seeds.unsqueeze(-1)).max(dim=1).indices + 1
all_indices = indices.unsqueeze(-1).expand(B, speculate) + torch.arange(speculate, device=device)
all_indices = torch.clamp(all_indices, max=input_ids.shape[1] - 1)
speculative_ids = input_ids.gather(dim=-1, index=all_indices)
return speculative_ids
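A toy run of create_n_gram_speculation above, with invented token ids: the last accepted token acts as a seed, its first earlier occurrence in the sequence is located, and the tokens that followed it are proposed as the speculative candidates.

import torch

# One sequence whose history already contains "7 8 9".
input_ids = torch.tensor([[5, 6, 7, 8, 9, 3, 7]])
next_ids = torch.tensor([7])            # flat over accepted tokens
accepted_ids = torch.tensor([1])
speculate = 2

seeds = next_ids[accepted_ids.cumsum(dim=-1) - 1]                     # tensor([7])
indices = (input_ids == seeds.unsqueeze(-1)).max(dim=1).indices + 1   # position right after the first 7
all_indices = indices.unsqueeze(-1).expand(1, speculate) + torch.arange(speculate)
all_indices = torch.clamp(all_indices, max=input_ids.shape[1] - 1)
speculative_ids = input_ids.gather(dim=-1, index=all_indices)
print(speculative_ids.tolist())         # [[8, 9]] -> guess that "8 9" follows again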
class HeterogeneousNextTokenChooser:
def __init__(
self,
@ -206,16 +222,72 @@ class HeterogeneousNextTokenChooser:
self.dtype = dtype
self.device = device
def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor):
if self.watermark_processor is not None:
scores = self.watermark_processor(input_ids, scores)
if self.repetition_processor is not None:
scores = self.repetition_processor(input_ids, scores)
def __call__(
self,
input_ids: torch.Tensor,
scores: torch.Tensor,
speculate: int,
speculated_ids: Optional[torch.Tensor] = None,
speculative_scores: Optional[torch.Tensor] = None,
verbose=False
):
if speculated_ids is not None:
B = scores.shape[0] // (speculated_ids.shape[1] + 1)
S = speculated_ids.shape[1] + 1
scores = scores.view(B, S, -1)
else:
B = scores.shape[0]
S = 1
scores = scores.view(B, S, -1)
for warper in self.warpers:
scores = warper(input_ids, scores)
next_ids = torch.zeros((B, S), device=scores.device, dtype=torch.long)
for j in range(S):
_scores = scores[:, j]
if self.watermark_processor is not None:
_scores = self.watermark_processor(input_ids, _scores)
if self.repetition_processor is not None:
_scores = self.repetition_processor(input_ids, _scores)
for warper in self.warpers:
_scores = warper(input_ids, _scores)
_next_ids = self.choice(_scores)
scores[:, j] = _scores
next_ids[:, j] = _next_ids
next_ids = next_ids.view(B*S)
scores = scores.view(B * S, -1)
if speculated_ids is not None:
accepted_ids = []
B = next_ids.shape[0] // (speculated_ids.shape[1] + 1)
S = speculated_ids.shape[1] + 1
indices = []
for i in range(B):
_next_ids = next_ids[i * S : (i + 1) * S]
_speculated_ids = speculated_ids[i]
validate_speculative = _next_ids[:-1] == _speculated_ids
index = i * S
accepted = 1
# First is always valid
indices.append(index)
for valid in validate_speculative.tolist():
if valid:
index += 1
accepted += 1
indices.append(index)
else:
break
accepted_ids.append(accepted)
accepted_ids = torch.tensor(accepted_ids, device=input_ids.device, dtype=input_ids.dtype)
next_ids = next_ids[indices]
scores = scores[indices]
indices = torch.arange(B, device=input_ids.device) * S
if speculative_scores is not None:
speculative_scores = speculative_scores[indices + accepted_ids - 1]
else:
accepted_ids = torch.ones_like(next_ids)
next_ids = self.choice(scores)
# ignore logprobs if we use greedy search
if type(self.choice) == Greedy:
logprobs = torch.empty_like(scores, device="cpu")
@ -224,7 +296,17 @@ class HeterogeneousNextTokenChooser:
logprobs = torch.log_softmax(scores, -1)
next_logprobs = torch.gather(logprobs, 1, next_ids.view(-1, 1)).view(-1)
return next_ids, next_logprobs, logprobs
if speculate > 0:
if speculative_scores is not None:
# Medusa provided some scores
speculative_ids = Greedy()(speculative_scores)
else:
# n-gram
speculative_ids = create_n_gram_speculation(input_ids, next_ids, accepted_ids, speculate, verbose)
else:
speculative_ids = None
return next_ids, next_logprobs, logprobs, accepted_ids, speculative_ids
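To make the acceptance loop above concrete, here is the verification rule in isolation, with invented ids: a position is kept only while every speculated token before it matched what the model actually produced, and the first mismatch cuts the window.

import torch

S = 3                                       # 2 speculated ids + 1 bonus token per sequence
speculated_ids = torch.tensor([[8, 9]])     # what was guessed last step
next_ids = torch.tensor([8, 4, 6])          # what the model really sampled at the 3 positions

validate = next_ids[:-1] == speculated_ids[0]    # tensor([True, False])
accepted = 1                                     # the first position is always valid
for ok in validate.tolist():
    if ok:
        accepted += 1
    else:
        break
print(accepted)                                  # 2 -> keep next_ids[:2] == [8, 4]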
def filter(self, indices):
if self.watermark_processor is not None: