Merge branch 'main' into support-legacy-completions-api

drbh 2024-02-21 10:48:28 -05:00 committed by GitHub
commit 07ac99a93f
29 changed files with 1828 additions and 198 deletions


@@ -41,6 +41,10 @@ jobs:
           components: rustfmt, clippy
       - name: Install Protoc
         uses: arduino/setup-protoc@v1
+      - name: Clean unused files
+        run: |
+          sudo rm -rf /usr/local/lib/android # will release about 10 GB if you don't need Android
+          sudo rm -rf /usr/share/dotnet # will release about 20GB if you don't need .NET
       - name: Install sccache
         run: |
           curl -fsSL https://github.com/mozilla/sccache/releases/download/v$SCCACHE/sccache-v$SCCACHE-x86_64-unknown-linux-musl.tar.gz | tar -xzv --strip-components=1 -C /usr/local/bin sccache-v$SCCACHE-x86_64-unknown-linux-musl/sccache

Cargo.lock (generated)

@ -19,12 +19,14 @@ checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe"
[[package]] [[package]]
name = "ahash" name = "ahash"
version = "0.8.8" version = "0.8.9"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "42cd52102d3df161c77a887b608d7a4897d7cc112886a9537b738a887a03aaff" checksum = "d713b3834d76b85304d4d525563c1276e2e30dc97cc67bfb4585a4a29fc2c89f"
dependencies = [ dependencies = [
"cfg-if", "cfg-if",
"getrandom",
"once_cell", "once_cell",
"serde",
"version_check", "version_check",
"zerocopy", "zerocopy",
] ]
@ -40,9 +42,9 @@ dependencies = [
[[package]] [[package]]
name = "anstream" name = "anstream"
version = "0.6.11" version = "0.6.12"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6e2e1ebcb11de5c03c67de28a7df593d32191b44939c482e97702baaaa6ab6a5" checksum = "96b09b5178381e0874812a9b157f7fe84982617e48f71f4e3235482775e5b540"
dependencies = [ dependencies = [
"anstyle", "anstyle",
"anstyle-parse", "anstyle-parse",
@ -88,9 +90,9 @@ dependencies = [
[[package]] [[package]]
name = "anyhow" name = "anyhow"
version = "1.0.79" version = "1.0.80"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "080e9890a082662b09c1ad45f567faeeb47f22b5fb23895fbe1e651e718e25ca" checksum = "5ad32ce52e4161730f7098c077cd2ed6229b5804ccf99e5366be1ab72a98b4e1"
[[package]] [[package]]
name = "arc-swap" name = "arc-swap"
@ -128,7 +130,7 @@ checksum = "16e62a023e7c117e27523144c5d2459f4397fcc3cab0085af8e2224f643a0193"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.49", "syn 2.0.50",
] ]
[[package]] [[package]]
@ -139,7 +141,7 @@ checksum = "c980ee35e870bd1a4d2c8294d4c04d0499e67bca1e4b5cefcc693c2fa00caea9"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.49", "syn 2.0.50",
] ]
[[package]] [[package]]
@ -265,6 +267,21 @@ version = "0.21.7"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567" checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567"
[[package]]
name = "bit-set"
version = "0.5.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0700ddab506f33b20a03b13996eccd309a48e5ff77d0d95926aa0210fb4e95f1"
dependencies = [
"bit-vec",
]
[[package]]
name = "bit-vec"
version = "0.6.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "349f9b6a179ed607305526ca489b34ad0a41aed5f7980fa90eb03160b69598fb"
[[package]] [[package]]
name = "bitflags" name = "bitflags"
version = "1.3.2" version = "1.3.2"
@ -288,9 +305,9 @@ dependencies = [
[[package]] [[package]]
name = "bumpalo" name = "bumpalo"
version = "3.15.0" version = "3.15.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d32a994c2b3ca201d9b263612a374263f05e7adde37c4707f693dcd375076d1f" checksum = "c764d619ca78fccbf3069b37bd7af92577f044bb15236036662d79b6559f25b7"
[[package]] [[package]]
name = "bytecount" name = "bytecount"
@ -350,12 +367,9 @@ checksum = "df8670b8c7b9dae1793364eafadf7239c40d669904660c5960d74cfd80b46a53"
[[package]] [[package]]
name = "cc" name = "cc"
version = "1.0.83" version = "1.0.86"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f1174fb0b6ec23863f8b971027804a42614e347eafb0a95bf0b12cdae21fc4d0" checksum = "7f9fa1897e4325be0d68d48df6aa1a71ac2ed4d27723887e7754192705350730"
dependencies = [
"libc",
]
[[package]] [[package]]
name = "cfg-if" name = "cfg-if"
@ -365,9 +379,9 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
[[package]] [[package]]
name = "clap" name = "clap"
version = "4.5.0" version = "4.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "80c21025abd42669a92efc996ef13cfb2c5c627858421ea58d5c3b331a6c134f" checksum = "c918d541ef2913577a0f9566e9ce27cb35b6df072075769e0b26cb5a554520da"
dependencies = [ dependencies = [
"clap_builder", "clap_builder",
"clap_derive", "clap_derive",
@ -375,9 +389,9 @@ dependencies = [
[[package]] [[package]]
name = "clap_builder" name = "clap_builder"
version = "4.5.0" version = "4.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "458bf1f341769dfcf849846f65dffdf9146daa56bcd2a47cb4e1de9915567c99" checksum = "9f3e7391dad68afb0c2ede1bf619f579a3dc9c2ec67f089baa397123a2f3d1eb"
dependencies = [ dependencies = [
"anstream", "anstream",
"anstyle", "anstyle",
@ -394,7 +408,7 @@ dependencies = [
"heck", "heck",
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.49", "syn 2.0.50",
] ]
[[package]] [[package]]
@ -716,6 +730,16 @@ dependencies = [
"cc", "cc",
] ]
[[package]]
name = "fancy-regex"
version = "0.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b95f7c0680e4142284cf8b22c14a476e87d61b004a3a0861872b32ef7ead40a2"
dependencies = [
"bit-set",
"regex",
]
[[package]] [[package]]
name = "fastrand" name = "fastrand"
version = "2.0.1" version = "2.0.1"
@ -780,6 +804,16 @@ dependencies = [
"percent-encoding", "percent-encoding",
] ]
[[package]]
name = "fraction"
version = "0.13.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3027ae1df8d41b4bed2241c8fdad4acc1e7af60c8e17743534b545e77182d678"
dependencies = [
"lazy_static",
"num",
]
[[package]] [[package]]
name = "futures" name = "futures"
version = "0.3.30" version = "0.3.30"
@ -836,7 +870,7 @@ checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.49", "syn 2.0.50",
] ]
[[package]] [[package]]
@ -895,8 +929,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "190092ea657667030ac6a35e305e62fc4dd69fd98ac98631e5d3a2b1575a12b5" checksum = "190092ea657667030ac6a35e305e62fc4dd69fd98ac98631e5d3a2b1575a12b5"
dependencies = [ dependencies = [
"cfg-if", "cfg-if",
"js-sys",
"libc", "libc",
"wasi", "wasi",
"wasm-bindgen",
] ]
[[package]] [[package]]
@ -1181,6 +1217,15 @@ version = "2.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8f518f335dce6725a761382244631d86cf0ccb2863413590b31338feb467f9c3" checksum = "8f518f335dce6725a761382244631d86cf0ccb2863413590b31338feb467f9c3"
[[package]]
name = "iso8601"
version = "0.6.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "924e5d73ea28f59011fec52a0d12185d496a9b075d360657aed2a5707f701153"
dependencies = [
"nom",
]
[[package]] [[package]]
name = "itertools" name = "itertools"
version = "0.10.5" version = "0.10.5"
@ -1223,6 +1268,36 @@ dependencies = [
"wasm-bindgen", "wasm-bindgen",
] ]
[[package]]
name = "jsonschema"
version = "0.17.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2a071f4f7efc9a9118dfb627a0a94ef247986e1ab8606a4c806ae2b3aa3b6978"
dependencies = [
"ahash",
"anyhow",
"base64 0.21.7",
"bytecount",
"clap",
"fancy-regex",
"fraction",
"getrandom",
"iso8601",
"itoa",
"memchr",
"num-cmp",
"once_cell",
"parking_lot",
"percent-encoding",
"regex",
"reqwest",
"serde",
"serde_json",
"time",
"url",
"uuid",
]
[[package]] [[package]]
name = "lazy_static" name = "lazy_static"
version = "1.4.0" version = "1.4.0"
@ -1363,7 +1438,7 @@ checksum = "38b4faf00617defe497754acde3024865bc143d44a86799b24e191ecff91354f"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.49", "syn 2.0.50",
] ]
[[package]] [[package]]
@ -1451,7 +1526,7 @@ checksum = "f686d68a09079e63b1d2c64aa305095887ce50565f00a922ebfaeeee0d9ba6ce"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.49", "syn 2.0.50",
] ]
[[package]] [[package]]
@ -1574,12 +1649,84 @@ dependencies = [
"winapi", "winapi",
] ]
[[package]]
name = "num"
version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b05180d69e3da0e530ba2a1dae5110317e49e3b7f3d41be227dc5f92e49ee7af"
dependencies = [
"num-bigint",
"num-complex",
"num-integer",
"num-iter",
"num-rational",
"num-traits",
]
[[package]]
name = "num-bigint"
version = "0.4.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "608e7659b5c3d7cba262d894801b9ec9d00de989e8a82bd4bef91d08da45cdc0"
dependencies = [
"autocfg",
"num-integer",
"num-traits",
]
[[package]]
name = "num-cmp"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "63335b2e2c34fae2fb0aa2cecfd9f0832a1e24b3b32ecec612c3426d46dc8aaa"
[[package]]
name = "num-complex"
version = "0.4.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "23c6602fda94a57c990fe0df199a035d83576b496aa29f4e634a8ac6004e68a6"
dependencies = [
"num-traits",
]
[[package]] [[package]]
name = "num-conv" name = "num-conv"
version = "0.1.0" version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9" checksum = "51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9"
[[package]]
name = "num-integer"
version = "0.1.46"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f"
dependencies = [
"num-traits",
]
[[package]]
name = "num-iter"
version = "0.1.44"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d869c01cc0c455284163fd0092f1f93835385ccab5a98a0dcc497b2f8bf055a9"
dependencies = [
"autocfg",
"num-integer",
"num-traits",
]
[[package]]
name = "num-rational"
version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0638a1c9d0a3c0914158145bc76cff373a75a627e6ecbfb71cbe6f453a5a19b0"
dependencies = [
"autocfg",
"num-bigint",
"num-integer",
"num-traits",
]
[[package]] [[package]]
name = "num-traits" name = "num-traits"
version = "0.2.18" version = "0.2.18"
@ -1654,9 +1801,9 @@ dependencies = [
[[package]] [[package]]
name = "openssl" name = "openssl"
version = "0.10.63" version = "0.10.64"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "15c9d69dd87a29568d4d017cfe8ec518706046a05184e5aea92d0af890b803c8" checksum = "95a0481286a310808298130d22dd1fef0fa571e05a8f44ec801801e84b216b1f"
dependencies = [ dependencies = [
"bitflags 2.4.2", "bitflags 2.4.2",
"cfg-if", "cfg-if",
@ -1675,7 +1822,7 @@ checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.49", "syn 2.0.50",
] ]
[[package]] [[package]]
@ -1686,9 +1833,9 @@ checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf"
[[package]] [[package]]
name = "openssl-sys" name = "openssl-sys"
version = "0.9.99" version = "0.9.100"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "22e1bf214306098e4832460f797824c05d25aacdf896f64a985fb0fd992454ae" checksum = "ae94056a791d0e1217d18b6cbdccb02c61e3054fc69893607f4067e3bb0b1fd1"
dependencies = [ dependencies = [
"cc", "cc",
"libc", "libc",
@ -1891,7 +2038,7 @@ checksum = "266c042b60c9c76b8d53061e52b2e0d1116abc57cefc8c5cd671619a56ac3690"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.49", "syn 2.0.50",
] ]
[[package]] [[package]]
@ -1937,7 +2084,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a41cf62165e97c7f814d2221421dbb9afcbcdb0a88068e5ea206e19951c2cbb5" checksum = "a41cf62165e97c7f814d2221421dbb9afcbcdb0a88068e5ea206e19951c2cbb5"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"syn 2.0.49", "syn 2.0.50",
] ]
[[package]] [[package]]
@ -2010,7 +2157,7 @@ dependencies = [
"prost 0.12.3", "prost 0.12.3",
"prost-types", "prost-types",
"regex", "regex",
"syn 2.0.49", "syn 2.0.50",
"tempfile", "tempfile",
"which", "which",
] ]
@ -2038,7 +2185,7 @@ dependencies = [
"itertools 0.11.0", "itertools 0.11.0",
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.49", "syn 2.0.50",
] ]
[[package]] [[package]]
@ -2289,16 +2436,17 @@ dependencies = [
[[package]] [[package]]
name = "ring" name = "ring"
version = "0.17.7" version = "0.17.8"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "688c63d65483050968b2a8937f7995f443e27041a0f7700aa59b0822aedebb74" checksum = "c17fa4cb658e3583423e915b9f3acc01cceaee1860e33d59ebae66adc3a2dc0d"
dependencies = [ dependencies = [
"cc", "cc",
"cfg-if",
"getrandom", "getrandom",
"libc", "libc",
"spin 0.9.8", "spin 0.9.8",
"untrusted 0.9.0", "untrusted 0.9.0",
"windows-sys 0.48.0", "windows-sys 0.52.0",
] ]
[[package]] [[package]]
@ -2322,7 +2470,7 @@ dependencies = [
"quote", "quote",
"rust-embed-utils", "rust-embed-utils",
"shellexpand", "shellexpand",
"syn 2.0.49", "syn 2.0.50",
"walkdir", "walkdir",
] ]
@ -2383,7 +2531,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e87c9956bd9807afa1f77e0f7594af32566e830e088a5576d27c5b6f30f49d41" checksum = "e87c9956bd9807afa1f77e0f7594af32566e830e088a5576d27c5b6f30f49d41"
dependencies = [ dependencies = [
"log", "log",
"ring 0.17.7", "ring 0.17.8",
"rustls-pki-types", "rustls-pki-types",
"rustls-webpki", "rustls-webpki",
"subtle", "subtle",
@ -2411,7 +2559,7 @@ version = "0.102.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "faaa0a62740bedb9b2ef5afa303da42764c012f743917351dc9a237ea1663610" checksum = "faaa0a62740bedb9b2ef5afa303da42764c012f743917351dc9a237ea1663610"
dependencies = [ dependencies = [
"ring 0.17.7", "ring 0.17.8",
"rustls-pki-types", "rustls-pki-types",
"untrusted 0.9.0", "untrusted 0.9.0",
] ]
@ -2424,9 +2572,9 @@ checksum = "7ffc183a10b4478d04cbbbfc96d0873219d962dd5accaff2ffbd4ceb7df837f4"
[[package]] [[package]]
name = "ryu" name = "ryu"
version = "1.0.16" version = "1.0.17"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f98d2aa92eebf49b69786be48e4477826b256916e84a57ff2a4f21923b48eb4c" checksum = "e86697c916019a8588c99b5fac3cead74ec0b4b819707a682fd4d23fa0ce1ba1"
[[package]] [[package]]
name = "same-file" name = "same-file"
@ -2458,7 +2606,7 @@ version = "0.7.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "da046153aa2352493d6cb7da4b6e5c0c057d8a1d0a9aa8560baffdd945acd414" checksum = "da046153aa2352493d6cb7da4b6e5c0c057d8a1d0a9aa8560baffdd945acd414"
dependencies = [ dependencies = [
"ring 0.17.7", "ring 0.17.8",
"untrusted 0.9.0", "untrusted 0.9.0",
] ]
@ -2487,38 +2635,38 @@ dependencies = [
[[package]] [[package]]
name = "semver" name = "semver"
version = "1.0.21" version = "1.0.22"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b97ed7a9823b74f99c7742f5336af7be5ecd3eeafcb1507d1fa93347b1d589b0" checksum = "92d43fe69e652f3df9bdc2b85b2854a0825b86e4fb76bc44d945137d053639ca"
dependencies = [ dependencies = [
"serde", "serde",
] ]
[[package]] [[package]]
name = "serde" name = "serde"
version = "1.0.196" version = "1.0.197"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "870026e60fa08c69f064aa766c10f10b1d62db9ccd4d0abb206472bee0ce3b32" checksum = "3fb1c873e1b9b056a4dc4c0c198b24c3ffa059243875552b2bd0933b1aee4ce2"
dependencies = [ dependencies = [
"serde_derive", "serde_derive",
] ]
[[package]] [[package]]
name = "serde_derive" name = "serde_derive"
version = "1.0.196" version = "1.0.197"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "33c85360c95e7d137454dc81d9a4ed2b8efd8fbe19cee57357b32b9771fccb67" checksum = "7eb0b34b42edc17f6b7cac84a52a1c5f0e1bb2227e997ca9011ea3dd34e8610b"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.49", "syn 2.0.50",
] ]
[[package]] [[package]]
name = "serde_json" name = "serde_json"
version = "1.0.113" version = "1.0.114"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "69801b70b1c3dac963ecb03a364ba0ceda9cf60c71cfe475e99864759c8b8a79" checksum = "c5f09b1bd632ef549eaa9f60a1f8de742bdbc698e6cee2095fc84dde5f549ae0"
dependencies = [ dependencies = [
"itoa", "itoa",
"ryu", "ryu",
@ -2701,7 +2849,7 @@ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"rustversion", "rustversion",
"syn 2.0.49", "syn 2.0.50",
] ]
[[package]] [[package]]
@ -2723,9 +2871,9 @@ dependencies = [
[[package]] [[package]]
name = "syn" name = "syn"
version = "2.0.49" version = "2.0.50"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "915aea9e586f80826ee59f8453c1101f9d1c4b3964cd2460185ee8e299ada496" checksum = "74f1bdc9872430ce9b75da68329d1c1746faf50ffac5f19e02b71e37ff881ffb"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
@ -2811,7 +2959,7 @@ dependencies = [
[[package]] [[package]]
name = "text-generation-benchmark" name = "text-generation-benchmark"
version = "1.4.1" version = "1.4.2"
dependencies = [ dependencies = [
"average", "average",
"clap", "clap",
@ -2832,7 +2980,7 @@ dependencies = [
[[package]] [[package]]
name = "text-generation-client" name = "text-generation-client"
version = "1.4.1" version = "1.4.2"
dependencies = [ dependencies = [
"futures", "futures",
"grpc-metadata", "grpc-metadata",
@ -2848,7 +2996,7 @@ dependencies = [
[[package]] [[package]]
name = "text-generation-launcher" name = "text-generation-launcher"
version = "1.4.1" version = "1.4.2"
dependencies = [ dependencies = [
"clap", "clap",
"ctrlc", "ctrlc",
@ -2864,7 +3012,7 @@ dependencies = [
[[package]] [[package]]
name = "text-generation-router" name = "text-generation-router"
version = "1.4.1" version = "1.4.2"
dependencies = [ dependencies = [
"async-stream", "async-stream",
"axum", "axum",
@ -2874,6 +3022,7 @@ dependencies = [
"futures-util", "futures-util",
"hf-hub", "hf-hub",
"init-tracing-opentelemetry", "init-tracing-opentelemetry",
"jsonschema",
"metrics", "metrics",
"metrics-exporter-prometheus", "metrics-exporter-prometheus",
"minijinja", "minijinja",
@ -2916,14 +3065,14 @@ checksum = "a953cb265bef375dae3de6663da4d3804eee9682ea80d8e2542529b73c531c81"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.49", "syn 2.0.50",
] ]
[[package]] [[package]]
name = "thread_local" name = "thread_local"
version = "1.1.7" version = "1.1.8"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3fdd6f064ccff2d6567adcb3873ca630700f00b5ad3f060c25b5dcfd9a4ce152" checksum = "8b9ef9bad013ada3808854ceac7b46812a6465ba368859a37e2100283d2d719c"
dependencies = [ dependencies = [
"cfg-if", "cfg-if",
"once_cell", "once_cell",
@ -3082,7 +3231,7 @@ checksum = "5b8a1e28f2deaa14e508979454cb3a223b10b938b45af148bc0986de36f1923b"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.49", "syn 2.0.50",
] ]
[[package]] [[package]]
@ -3197,7 +3346,7 @@ dependencies = [
"proc-macro2", "proc-macro2",
"prost-build", "prost-build",
"quote", "quote",
"syn 2.0.49", "syn 2.0.50",
] ]
[[package]] [[package]]
@ -3270,7 +3419,7 @@ checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.49", "syn 2.0.50",
] ]
[[package]] [[package]]
@ -3400,9 +3549,9 @@ checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b"
[[package]] [[package]]
name = "unicode-normalization" name = "unicode-normalization"
version = "0.1.22" version = "0.1.23"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5c5713f0fc4b5db668a2ac63cdb7bb4469d8c9fed047b1d0292cc7b0ce2ba921" checksum = "a56d1686db2308d901306f92a263857ef59ea39678a5458e7cb17f01415101f5"
dependencies = [ dependencies = [
"tinyvec", "tinyvec",
] ]
@ -3511,7 +3660,7 @@ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"regex", "regex",
"syn 2.0.49", "syn 2.0.50",
] ]
[[package]] [[package]]
@ -3530,6 +3679,12 @@ dependencies = [
"zip", "zip",
] ]
[[package]]
name = "uuid"
version = "1.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f00cc9702ca12d3c81455259621e676d0f7251cec66a21e98fe2e9a37db93b2a"
[[package]] [[package]]
name = "valuable" name = "valuable"
version = "0.1.0" version = "0.1.0"
@ -3610,7 +3765,7 @@ dependencies = [
"once_cell", "once_cell",
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.49", "syn 2.0.50",
"wasm-bindgen-shared", "wasm-bindgen-shared",
] ]
@ -3644,7 +3799,7 @@ checksum = "642f325be6301eb8107a83d12a8ac6c1e1c54345a7ef1a9261962dfefda09e66"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.49", "syn 2.0.50",
"wasm-bindgen-backend", "wasm-bindgen-backend",
"wasm-bindgen-shared", "wasm-bindgen-shared",
] ]
@ -3671,7 +3826,7 @@ version = "0.22.4"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ed63aea5ce73d0ff405984102c42de94fc55a6b75765d621c65262469b3c9b53" checksum = "ed63aea5ce73d0ff405984102c42de94fc55a6b75765d621c65262469b3c9b53"
dependencies = [ dependencies = [
"ring 0.17.7", "ring 0.17.8",
"untrusted 0.9.0", "untrusted 0.9.0",
] ]
@ -3971,7 +4126,7 @@ checksum = "9ce1b18ccd8e73a9321186f97e46f9f04b778851177567b1975109d26a08d2a6"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.49", "syn 2.0.50",
] ]
[[package]] [[package]]


@@ -9,7 +9,7 @@ members = [
 resolver = "2"
 [workspace.package]
-version = "1.4.1"
+version = "1.4.2"
 edition = "2021"
 authors = ["Olivier Dehaene"]
 homepage = "https://github.com/huggingface/text-generation-inference"


@@ -10,7 +10,7 @@
"name": "Apache 2.0",
"url": "https://www.apache.org/licenses/LICENSE-2.0"
},
-"version": "1.4.1"
+"version": "1.4.2"
},
"paths": {
"/": {
@@ -637,6 +637,35 @@
}
}
},
"ChatCompletionComplete": {
"type": "object",
"required": [
"index",
"message",
"finish_reason"
],
"properties": {
"finish_reason": {
"type": "string"
},
"index": {
"type": "integer",
"format": "int32",
"minimum": 0
},
"logprobs": {
"allOf": [
{
"$ref": "#/components/schemas/ChatCompletionLogprobs"
}
],
"nullable": true
},
"message": {
"$ref": "#/components/schemas/Message"
}
}
},
"ChatCompletionDelta": { "ChatCompletionDelta": {
"type": "object", "type": "object",
"required": [ "required": [
@ -654,6 +683,59 @@
} }
} }
}, },
"ChatCompletionLogprob": {
"type": "object",
"required": [
"token",
"logprob",
"top_logprobs"
],
"properties": {
"logprob": {
"type": "number",
"format": "float"
},
"token": {
"type": "string"
},
"top_logprobs": {
"type": "array",
"items": {
"$ref": "#/components/schemas/ChatCompletionTopLogprob"
}
}
}
},
"ChatCompletionLogprobs": {
"type": "object",
"required": [
"content"
],
"properties": {
"content": {
"type": "array",
"items": {
"$ref": "#/components/schemas/ChatCompletionLogprob"
}
}
}
},
"ChatCompletionTopLogprob": {
"type": "object",
"required": [
"token",
"logprob"
],
"properties": {
"logprob": {
"type": "number",
"format": "float"
},
"token": {
"type": "string"
}
}
},
"ChatRequest": { "ChatRequest": {
"type": "object", "type": "object",
"required": [ "required": [
@ -1022,6 +1104,49 @@
} }
} }
}, },
"GrammarType": {
"oneOf": [
{
"type": "object",
"required": [
"type",
"value"
],
"properties": {
"type": {
"type": "string",
"enum": [
"json"
]
},
"value": {
"description": "A string that represents a [JSON Schema](https://json-schema.org/).\n\nJSON Schema is a declarative language that allows to annotate JSON documents\nwith types and descriptions."
}
}
},
{
"type": "object",
"required": [
"type",
"value"
],
"properties": {
"type": {
"type": "string",
"enum": [
"regex"
]
},
"value": {
"type": "string"
}
}
}
],
"discriminator": {
"propertyName": "type"
}
},
"Info": { "Info": {
"type": "object", "type": "object",
"required": [ "required": [
@ -1311,6 +1436,31 @@
"items": { "items": {
"$ref": "#/components/schemas/SimpleToken" "$ref": "#/components/schemas/SimpleToken"
} }
},
"Usage": {
"type": "object",
"required": [
"prompt_tokens",
"completion_tokens",
"total_tokens"
],
"properties": {
"completion_tokens": {
"type": "integer",
"format": "int32",
"minimum": 0
},
"prompt_tokens": {
"type": "integer",
"format": "int32",
"minimum": 0
},
"total_tokens": {
"type": "integer",
"format": "int32",
"minimum": 0
}
}
}
}
},
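For readers of the schema additions above: `GrammarType` is a discriminated union keyed by `type`, with the grammar itself carried under `value`. A minimal sketch of the two documented payload shapes, built with `serde_json`; the schema and regex shown are illustrative, not taken from this commit.

use serde_json::json;

fn main() {
    // "json" grammar: the value is a JSON Schema document (illustrative schema).
    let json_grammar = json!({
        "type": "json",
        "value": {
            "properties": { "location": { "type": "string" } },
            "required": ["location"]
        }
    });

    // "regex" grammar: the value is a plain regular expression string.
    let regex_grammar = json!({
        "type": "regex",
        "value": "\\d{4}-\\d{2}-\\d{2}"
    });

    println!("{json_grammar}\n{regex_grammar}");
}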


@@ -40,6 +40,9 @@ class ResponseComparator(JSONSnapshotExtension):
         exclude=None,
         matcher=None,
     ):
+        if isinstance(data, Response):
+            data = data.dict()
+
         if isinstance(data, List):
             data = [d.dict() for d in data]


@ -0,0 +1,89 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 2,
"logprob": null,
"text": "<bos>"
},
{
"id": 2015,
"logprob": -10.0,
"text": "Test"
},
{
"id": 3853,
"logprob": -10.875,
"text": " request"
}
],
"seed": null,
"tokens": [
{
"id": 1736,
"logprob": -2.09375,
"special": false,
"text": " form"
},
{
"id": 109,
"logprob": -1.8671875,
"special": false,
"text": "\n\n"
},
{
"id": 651,
"logprob": -2.4375,
"special": false,
"text": "The"
},
{
"id": 2121,
"logprob": -1.8203125,
"special": false,
"text": " test"
},
{
"id": 3853,
"logprob": -0.23242188,
"special": false,
"text": " request"
},
{
"id": 1736,
"logprob": -0.08544922,
"special": false,
"text": " form"
},
{
"id": 603,
"logprob": -0.9375,
"special": false,
"text": " is"
},
{
"id": 1671,
"logprob": -1.671875,
"special": false,
"text": " used"
},
{
"id": 577,
"logprob": -0.40429688,
"special": false,
"text": " to"
},
{
"id": 3853,
"logprob": -1.1875,
"special": false,
"text": " request"
}
],
"top_tokens": null
},
"generated_text": " form\n\nThe test request form is used to request"
}


@ -0,0 +1,89 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 2,
"logprob": null,
"text": "<bos>"
},
{
"id": 2015,
"logprob": -10.0,
"text": "Test"
},
{
"id": 3853,
"logprob": -10.875,
"text": " request"
}
],
"seed": 0,
"tokens": [
{
"id": 7539,
"logprob": -0.73046875,
"special": false,
"text": " forms"
},
{
"id": 708,
"logprob": 0.0,
"special": false,
"text": " are"
},
{
"id": 671,
"logprob": -1.703125,
"special": false,
"text": " an"
},
{
"id": 8727,
"logprob": 0.0,
"special": false,
"text": " essential"
},
{
"id": 1702,
"logprob": 0.0,
"special": false,
"text": " part"
},
{
"id": 576,
"logprob": 0.0,
"special": false,
"text": " of"
},
{
"id": 573,
"logprob": 0.0,
"special": false,
"text": " the"
},
{
"id": 11859,
"logprob": -1.6953125,
"special": false,
"text": " lab"
},
{
"id": 2185,
"logprob": -1.3125,
"special": false,
"text": " process"
},
{
"id": 578,
"logprob": -1.5,
"special": false,
"text": " and"
}
],
"top_tokens": null
},
"generated_text": "Test request forms are an essential part of the lab process and"
}


@ -0,0 +1,358 @@
[
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 2,
"logprob": null,
"text": "<bos>"
},
{
"id": 2015,
"logprob": -10.0,
"text": "Test"
},
{
"id": 3853,
"logprob": -10.875,
"text": " request"
}
],
"seed": null,
"tokens": [
{
"id": 1736,
"logprob": -2.09375,
"special": false,
"text": " form"
},
{
"id": 109,
"logprob": -1.9140625,
"special": false,
"text": "\n\n"
},
{
"id": 651,
"logprob": -2.453125,
"special": false,
"text": "The"
},
{
"id": 2121,
"logprob": -1.8984375,
"special": false,
"text": " test"
},
{
"id": 3853,
"logprob": -0.23535156,
"special": false,
"text": " request"
},
{
"id": 1736,
"logprob": -0.091308594,
"special": false,
"text": " form"
},
{
"id": 603,
"logprob": -0.96875,
"special": false,
"text": " is"
},
{
"id": 1671,
"logprob": -1.6484375,
"special": false,
"text": " used"
},
{
"id": 577,
"logprob": -0.43164062,
"special": false,
"text": " to"
},
{
"id": 3853,
"logprob": -1.2421875,
"special": false,
"text": " request"
}
],
"top_tokens": null
},
"generated_text": " form\n\nThe test request form is used to request"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 2,
"logprob": null,
"text": "<bos>"
},
{
"id": 2015,
"logprob": -10.0,
"text": "Test"
},
{
"id": 3853,
"logprob": -10.875,
"text": " request"
}
],
"seed": null,
"tokens": [
{
"id": 1736,
"logprob": -2.09375,
"special": false,
"text": " form"
},
{
"id": 109,
"logprob": -1.9140625,
"special": false,
"text": "\n\n"
},
{
"id": 651,
"logprob": -2.453125,
"special": false,
"text": "The"
},
{
"id": 2121,
"logprob": -1.8984375,
"special": false,
"text": " test"
},
{
"id": 3853,
"logprob": -0.23535156,
"special": false,
"text": " request"
},
{
"id": 1736,
"logprob": -0.091308594,
"special": false,
"text": " form"
},
{
"id": 603,
"logprob": -0.96875,
"special": false,
"text": " is"
},
{
"id": 1671,
"logprob": -1.6484375,
"special": false,
"text": " used"
},
{
"id": 577,
"logprob": -0.43164062,
"special": false,
"text": " to"
},
{
"id": 3853,
"logprob": -1.2421875,
"special": false,
"text": " request"
}
],
"top_tokens": null
},
"generated_text": " form\n\nThe test request form is used to request"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 2,
"logprob": null,
"text": "<bos>"
},
{
"id": 2015,
"logprob": -10.0,
"text": "Test"
},
{
"id": 3853,
"logprob": -10.875,
"text": " request"
}
],
"seed": null,
"tokens": [
{
"id": 1736,
"logprob": -2.09375,
"special": false,
"text": " form"
},
{
"id": 109,
"logprob": -1.9140625,
"special": false,
"text": "\n\n"
},
{
"id": 651,
"logprob": -2.453125,
"special": false,
"text": "The"
},
{
"id": 2121,
"logprob": -1.8984375,
"special": false,
"text": " test"
},
{
"id": 3853,
"logprob": -0.23535156,
"special": false,
"text": " request"
},
{
"id": 1736,
"logprob": -0.091308594,
"special": false,
"text": " form"
},
{
"id": 603,
"logprob": -0.96875,
"special": false,
"text": " is"
},
{
"id": 1671,
"logprob": -1.6484375,
"special": false,
"text": " used"
},
{
"id": 577,
"logprob": -0.43164062,
"special": false,
"text": " to"
},
{
"id": 3853,
"logprob": -1.2421875,
"special": false,
"text": " request"
}
],
"top_tokens": null
},
"generated_text": " form\n\nThe test request form is used to request"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 2,
"logprob": null,
"text": "<bos>"
},
{
"id": 2015,
"logprob": -10.0,
"text": "Test"
},
{
"id": 3853,
"logprob": -10.875,
"text": " request"
}
],
"seed": null,
"tokens": [
{
"id": 1736,
"logprob": -2.09375,
"special": false,
"text": " form"
},
{
"id": 109,
"logprob": -1.9140625,
"special": false,
"text": "\n\n"
},
{
"id": 651,
"logprob": -2.453125,
"special": false,
"text": "The"
},
{
"id": 2121,
"logprob": -1.8984375,
"special": false,
"text": " test"
},
{
"id": 3853,
"logprob": -0.23535156,
"special": false,
"text": " request"
},
{
"id": 1736,
"logprob": -0.091308594,
"special": false,
"text": " form"
},
{
"id": 603,
"logprob": -0.96875,
"special": false,
"text": " is"
},
{
"id": 1671,
"logprob": -1.6484375,
"special": false,
"text": " used"
},
{
"id": 577,
"logprob": -0.43164062,
"special": false,
"text": " to"
},
{
"id": 3853,
"logprob": -1.2421875,
"special": false,
"text": " request"
}
],
"top_tokens": null
},
"generated_text": " form\n\nThe test request form is used to request"
}
]


@ -135,129 +135,129 @@
"special": false, "special": false,
"text": "\",\"" "text": "\",\""
}, },
{
"id": 4230,
"logprob": -0.020492554,
"special": false,
"text": "last"
},
{
"id": 1170,
"logprob": -0.0013818741,
"special": false,
"text": "Name"
},
{
"id": 4710,
"logprob": -0.0067749023,
"special": false,
"text": "\":\""
},
{
"id": 29950,
"logprob": -0.11578369,
"special": false,
"text": "H"
},
{
"id": 14339,
"logprob": -0.004131317,
"special": false,
"text": "olt"
},
{
"id": 29920,
"logprob": -0.0033359528,
"special": false,
"text": "z"
},
{
"id": 3284,
"logprob": -0.20471191,
"special": false,
"text": "\",\""
},
{ {
"id": 29882, "id": 29882,
"logprob": -0.0069274902, "logprob": -0.08862305,
"special": false, "special": false,
"text": "h" "text": "h"
}, },
{ {
"id": 20838, "id": 711,
"logprob": -0.19580078, "logprob": -0.66259766,
"special": false, "special": false,
"text": "obb" "text": "ob"
}, },
{ {
"id": 29891, "id": 1609,
"logprob": -2.2649765e-06, "logprob": -5.51939e-05,
"special": false, "special": false,
"text": "y" "text": "by"
}, },
{ {
"id": 4710, "id": 4710,
"logprob": -0.32080078, "logprob": -0.23120117,
"special": false, "special": false,
"text": "\":\"" "text": "\":\""
}, },
{ {
"id": 29911, "id": 29911,
"logprob": -2.1035156, "logprob": -2.3730469,
"special": false, "special": false,
"text": "T" "text": "T"
}, },
{ {
"id": 11003, "id": 11003,
"logprob": -0.020767212, "logprob": -0.032104492,
"special": false, "special": false,
"text": "rees" "text": "rees"
}, },
{ {
"id": 3284, "id": 3284,
"logprob": -0.6010742, "logprob": -0.22021484,
"special": false,
"text": "\",\""
},
{
"id": 4230,
"logprob": -0.06726074,
"special": false,
"text": "last"
},
{
"id": 1170,
"logprob": -0.003501892,
"special": false,
"text": "Name"
},
{
"id": 4710,
"logprob": -0.0045661926,
"special": false,
"text": "\":\""
},
{
"id": 29950,
"logprob": -0.12512207,
"special": false,
"text": "H"
},
{
"id": 14339,
"logprob": -0.009552002,
"special": false,
"text": "olt"
},
{
"id": 29920,
"logprob": -0.00042438507,
"special": false,
"text": "z"
},
{
"id": 3284,
"logprob": -0.11651611,
"special": false, "special": false,
"text": "\",\"" "text": "\",\""
}, },
{ {
"id": 29876, "id": 29876,
"logprob": -0.57666016, "logprob": -0.29736328,
"special": false, "special": false,
"text": "n" "text": "n"
}, },
{ {
"id": 398, "id": 398,
"logprob": -0.0061073303, "logprob": -0.003030777,
"special": false, "special": false,
"text": "um" "text": "um"
}, },
{ {
"id": 29907, "id": 29907,
"logprob": -0.45703125, "logprob": -0.3774414,
"special": false, "special": false,
"text": "C" "text": "C"
}, },
{ {
"id": 1446, "id": 1446,
"logprob": -0.0002872944, "logprob": -0.0003130436,
"special": false, "special": false,
"text": "ats" "text": "ats"
}, },
{ {
"id": 1115, "id": 1115,
"logprob": -0.0021018982, "logprob": -0.0021514893,
"special": false, "special": false,
"text": "\":" "text": "\":"
}, },
{ {
"id": 29906, "id": 29906,
"logprob": -0.08996582, "logprob": -0.071899414,
"special": false, "special": false,
"text": "2" "text": "2"
}, },
{ {
"id": 29913, "id": 29913,
"logprob": -0.021697998, "logprob": -0.018997192,
"special": false, "special": false,
"text": "}" "text": "}"
}, },
@@ -270,5 +270,5 @@
   ],
   "top_tokens": null
 },
-"generated_text": "{\"firstName\":\"David\",\"lastName\":\"Holtz\",\"hobby\":\"Trees\",\"numCats\":2}"
+"generated_text": "{\"firstName\":\"David\",\"hobby\":\"Trees\",\"lastName\":\"Holtz\",\"numCats\":2}"
 }


@@ -18,7 +18,6 @@ async def flash_llama_awq(flash_llama_awq_handle):
 @pytest.mark.asyncio
-@pytest.mark.private
 async def test_flash_llama_awq(flash_llama_awq, response_snapshot):
     response = await flash_llama_awq.generate(
         "What is Deep Learning?", max_new_tokens=10, decoder_input_details=True
@@ -33,7 +32,6 @@ async def test_flash_llama_awq(flash_llama_awq, response_snapshot):
 @pytest.mark.asyncio
-@pytest.mark.private
 async def test_flash_llama_awq_all_params(flash_llama_awq, response_snapshot):
     response = await flash_llama_awq.generate(
         "What is Deep Learning?",
@@ -55,7 +53,6 @@ async def test_flash_llama_awq_all_params(flash_llama_awq, response_snapshot):
 @pytest.mark.asyncio
-@pytest.mark.private
 async def test_flash_llama_awq_load(flash_llama_awq, generate_load, response_snapshot):
     responses = await generate_load(
         flash_llama_awq, "What is Deep Learning?", max_new_tokens=10, n=4


@@ -18,7 +18,6 @@ async def flash_llama_awq_sharded(flash_llama_awq_handle_sharded):
 @pytest.mark.asyncio
-@pytest.mark.private
 async def test_flash_llama_awq_sharded(flash_llama_awq_sharded, response_snapshot):
     response = await flash_llama_awq_sharded.generate(
         "What is Deep Learning?", max_new_tokens=10, decoder_input_details=True
@@ -33,7 +32,6 @@ async def test_flash_llama_awq_sharded(flash_llama_awq_sharded, response_snapsho
 @pytest.mark.asyncio
-@pytest.mark.private
 async def test_flash_llama_awq_load_sharded(
     flash_llama_awq_sharded, generate_load, response_snapshot
 ):


@ -0,0 +1,61 @@
import pytest
@pytest.fixture(scope="module")
def flash_gemma_handle(launcher):
with launcher("gg-hf/gemma-2b", num_shard=1) as handle:
yield handle
@pytest.fixture(scope="module")
async def flash_gemma(flash_gemma_handle):
await flash_gemma_handle.health(300)
return flash_gemma_handle.client
@pytest.mark.skip
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_gemma(flash_gemma, response_snapshot):
response = await flash_gemma.generate(
"Test request", max_new_tokens=10, decoder_input_details=True
)
assert response.details.generated_tokens == 10
assert response == response_snapshot
@pytest.mark.skip
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_gemma_all_params(flash_gemma, response_snapshot):
response = await flash_gemma.generate(
"Test request",
max_new_tokens=10,
repetition_penalty=1.2,
return_full_text=True,
stop_sequences=["test"],
temperature=0.5,
top_p=0.9,
top_k=10,
truncate=5,
typical_p=0.9,
watermark=True,
decoder_input_details=True,
seed=0,
)
assert response.details.generated_tokens == 10
assert response == response_snapshot
@pytest.mark.skip
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_gemma_load(flash_gemma, generate_load, response_snapshot):
responses = await generate_load(flash_gemma, "Test request", max_new_tokens=10, n=4)
assert len(responses) == 4
assert all([r.generated_text == responses[0].generated_text for r in responses])
assert responses == response_snapshot


@@ -14,7 +14,6 @@ async def flash_medusa(flash_medusa_handle):
 @pytest.mark.asyncio
-@pytest.mark.private
 async def test_flash_medusa_simple(flash_medusa, response_snapshot):
     response = await flash_medusa.generate(
         "What is Deep Learning?", max_new_tokens=10, decoder_input_details=True
@@ -25,7 +24,6 @@ async def test_flash_medusa_simple(flash_medusa, response_snapshot):
 @pytest.mark.asyncio
-@pytest.mark.private
 async def test_flash_medusa_all_params(flash_medusa, response_snapshot):
     response = await flash_medusa.generate(
         "What is Deep Learning?",
@@ -48,7 +46,6 @@ async def test_flash_medusa_all_params(flash_medusa, response_snapshot):
 @pytest.mark.asyncio
-@pytest.mark.private
 async def test_flash_medusa_load(flash_medusa, generate_load, response_snapshot):
     responses = await generate_load(
         flash_medusa, "What is Deep Learning?", max_new_tokens=10, n=4


@@ -14,7 +14,6 @@ async def flash_mistral(flash_mistral_handle):
 @pytest.mark.asyncio
-@pytest.mark.private
 async def test_flash_mistral(flash_mistral, response_snapshot):
     response = await flash_mistral.generate(
         "Test request", max_new_tokens=10, decoder_input_details=True
@@ -26,7 +25,6 @@ async def test_flash_mistral(flash_mistral, response_snapshot):
 @pytest.mark.asyncio
-@pytest.mark.private
 async def test_flash_mistral_all_params(flash_mistral, response_snapshot):
     response = await flash_mistral.generate(
         "Test request",
@@ -49,7 +47,6 @@ async def test_flash_mistral_all_params(flash_mistral, response_snapshot):
 @pytest.mark.asyncio
-@pytest.mark.private
 async def test_flash_mistral_load(flash_mistral, generate_load, response_snapshot):
     responses = await generate_load(
         flash_mistral, "Test request", max_new_tokens=10, n=4


@@ -14,7 +14,6 @@ async def flash_phi(flash_phi_handle):
 @pytest.mark.asyncio
-@pytest.mark.private
 async def test_flash_phi(flash_phi, response_snapshot):
     response = await flash_phi.generate(
         "Test request", max_new_tokens=10, decoder_input_details=True
@@ -26,7 +25,6 @@ async def test_flash_phi(flash_phi, response_snapshot):
 @pytest.mark.asyncio
-@pytest.mark.private
 async def test_flash_phi_all_params(flash_phi, response_snapshot):
     response = await flash_phi.generate(
         "Test request",
@@ -50,7 +48,6 @@ async def test_flash_phi_all_params(flash_phi, response_snapshot):
 @pytest.mark.asyncio
-@pytest.mark.private
 async def test_flash_phi_load(flash_phi, generate_load, response_snapshot):
     responses = await generate_load(flash_phi, "Test request", max_new_tokens=10, n=4)


@@ -14,7 +14,6 @@ async def flash_starcoder_gptq(flash_starcoder_gptq_handle):
 @pytest.mark.asyncio
-@pytest.mark.private
 async def test_flash_starcoder_gptq(flash_starcoder_gptq, generous_response_snapshot):
     response = await flash_starcoder_gptq.generate(
         "def geometric_mean(L: List[float]):",
@@ -26,7 +25,6 @@ async def test_flash_starcoder_gptq(flash_starcoder_gptq, generous_response_snap
 @pytest.mark.asyncio
-@pytest.mark.private
 async def test_flash_starcoder_gptq_default_params(
     flash_starcoder_gptq, generous_response_snapshot
 ):
@@ -43,7 +41,6 @@ async def test_flash_starcoder_gptq_default_params(
 @pytest.mark.asyncio
-@pytest.mark.private
 async def test_flash_starcoder_gptq_load(
     flash_starcoder_gptq, generate_load, generous_response_snapshot
 ):


@@ -19,7 +19,6 @@ async def flash_llama_grammar(flash_llama_grammar_handle):
 @pytest.mark.asyncio
-@pytest.mark.private
 async def test_flash_llama_grammar(flash_llama_grammar, response_snapshot):
     response = await flash_llama_grammar.generate(
         "Test request", max_new_tokens=10, decoder_input_details=True
@@ -30,7 +29,6 @@ async def test_flash_llama_grammar(flash_llama_grammar, response_snapshot):
 @pytest.mark.asyncio
-@pytest.mark.private
 async def test_flash_llama_grammar_regex(flash_llama_grammar, response_snapshot):
     response = await flash_llama_grammar.generate(
         "Whats Googles DNS",
@@ -49,7 +47,6 @@ async def test_flash_llama_grammar_regex(flash_llama_grammar, response_snapshot)
 @pytest.mark.asyncio
-@pytest.mark.private
 async def test_flash_llama_grammar_json(flash_llama_grammar, response_snapshot):
     response = await flash_llama_grammar.generate(
         "info: david holtz like trees and has two cats. ",
@@ -92,13 +89,12 @@ async def test_flash_llama_grammar_json(flash_llama_grammar, response_snapshot):
     assert response.details.generated_tokens == 30
     assert (
         response.generated_text
-        == '{"firstName":"David","lastName":"Holtz","hobby":"Trees","numCats":2}'
+        == '{"firstName":"David","hobby":"Trees","lastName":"Holtz","numCats":2}'
     )
     assert response == response_snapshot
 @pytest.mark.asyncio
-@pytest.mark.private
 async def test_flash_llama_grammar_load(
     flash_llama_grammar, generate_load, response_snapshot
 ):
@@ -130,7 +126,6 @@ async def test_flash_llama_grammar_load(
 # this is the same as the above test, but only fires off a single request
 # this is only to ensure that the parallel and single inference produce the same result
 @pytest.mark.asyncio
-@pytest.mark.private
 async def test_flash_llama_grammar_single_load_instance(
     flash_llama_grammar, generate_load, response_snapshot
 ):


@@ -14,7 +14,6 @@ async def fused_kernel_mamba(fused_kernel_mamba_handle):
 @pytest.mark.asyncio
-@pytest.mark.private
 async def test_mamba(fused_kernel_mamba, response_snapshot):
     response = await fused_kernel_mamba.generate(
         "What is Deep Learning?", max_new_tokens=10
@@ -26,7 +25,6 @@ async def test_mamba(fused_kernel_mamba, response_snapshot):
 @pytest.mark.asyncio
-@pytest.mark.private
 async def test_mamba_all_params(fused_kernel_mamba, response_snapshot):
     response = await fused_kernel_mamba.generate(
         "blue, red, yellow, ",
@@ -53,7 +51,6 @@ async def test_mamba_all_params(fused_kernel_mamba, response_snapshot):
 @pytest.mark.asyncio
-@pytest.mark.private
 async def test_mamba_load(
     fused_kernel_mamba, generate_load, generous_response_snapshot
 ):


@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "text-generation-integration-tests"
-version = "1.4.1"
+version = "1.4.2"
 description = "Text Generation Inference integration tests"
 authors = ["Nicolas Patry <nicolas@huggingface.co>"]


@@ -22,6 +22,7 @@ text-generation-client = { path = "client" }
 clap = { version = "4.4.5", features = ["derive", "env"] }
 futures = "0.3.28"
 hf-hub = { version = "0.3.0", features = ["tokio"] }
+jsonschema = { version = "0.17.1", features = ["draft202012"] }
 metrics = "0.21.1"
 metrics-exporter-prometheus = { version = "0.12.1", features = [] }
 nohash-hasher = "0.2.0"


@@ -65,39 +65,16 @@ impl HubTokenizerConfig {
     }
 }
-mod json_object_or_string_to_string {
-    use serde::{Deserialize, Deserializer};
-    use serde_json::Value;
-    // A custom deserializer that treats both strings and objects as strings.
-    // This provides flexibility with input formats for the 'grammar' field.
-    pub fn deserialize<'de, D>(deserializer: D) -> Result<String, D::Error>
-    where
-        D: Deserializer<'de>,
-    {
-        let value = Value::deserialize(deserializer)?;
-        match value {
-            Value::String(s) => Ok(s),
-            // Safely handle serialization and return an error if it fails
-            Value::Object(o) => {
-                serde_json::to_string(&o).map_err(|e| serde::de::Error::custom(e.to_string()))
-            }
-            _ => Err(serde::de::Error::custom(
-                "expected string or object for grammar",
-            )),
-        }
-    }
-}
 #[derive(Clone, Debug, Deserialize, ToSchema)]
 #[serde(tag = "type", content = "value")]
 pub(crate) enum GrammarType {
-    #[serde(
-        rename = "json",
-        deserialize_with = "json_object_or_string_to_string::deserialize"
-    )]
-    Json(String),
+    /// A string that represents a [JSON Schema](https://json-schema.org/).
+    ///
+    /// JSON Schema is a declarative language that allows to annotate JSON documents
+    /// with types and descriptions.
+    #[serde(rename = "json")]
+    #[schema(example = json ! ({"properties": {"location":{"type": "string"}}}))]
+    Json(serde_json::Value),
     #[serde(rename = "regex")]
     Regex(String),
 }
@@ -416,7 +393,7 @@ pub(crate) struct ChatCompletionTopLogprob {
     logprob: f32,
 }
-#[derive(Clone, Deserialize, Serialize, Default)]
+#[derive(Clone, Deserialize, Serialize, ToSchema, Default)]
 pub(crate) struct Usage {
     pub prompt_tokens: u32,
     pub completion_tokens: u32,
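A note on the `GrammarType` change above: with the custom string-coercing deserializer removed, the adjacently tagged enum now keeps the `json` value as a raw `serde_json::Value`, and string inputs are only re-parsed later during validation. A minimal standalone sketch of how such payloads deserialize, using a local mirror of the enum rather than the router's actual type:

use serde::Deserialize;

// Illustrative mirror of the router's GrammarType (without the utoipa ToSchema derive).
#[derive(Debug, Deserialize)]
#[serde(tag = "type", content = "value")]
enum GrammarType {
    #[serde(rename = "json")]
    Json(serde_json::Value),
    #[serde(rename = "regex")]
    Regex(String),
}

fn main() -> Result<(), serde_json::Error> {
    // An inline JSON Schema object stays a serde_json::Value as-is.
    let from_object: GrammarType = serde_json::from_str(
        r#"{"type": "json", "value": {"properties": {"location": {"type": "string"}}}}"#,
    )?;

    // A pre-serialized schema string also deserializes; it is parsed again at validation time.
    let from_string: GrammarType = serde_json::from_str(
        r#"{"type": "json", "value": "{\"properties\": {\"location\": {\"type\": \"string\"}}}"}"#,
    )?;

    let regex: GrammarType = serde_json::from_str(r#"{"type": "regex", "value": "\\d+"}"#)?;

    println!("{from_object:?}\n{from_string:?}\n{regex:?}");
    Ok(())
}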


@@ -270,7 +270,7 @@ async fn main() -> Result<(), RouterError> {
     let compat_return_full_text = match &model_info.pipeline_tag {
         None => {
             tracing::warn!("no pipeline tag found for model {tokenizer_name}");
-            false
+            true
         }
         Some(pipeline_tag) => pipeline_tag.as_str() == "text-generation",
     };


@@ -1076,11 +1076,16 @@ pub async fn run(
     Info,
     CompatGenerateRequest,
     GenerateRequest,
+    GrammarType,
     ChatRequest,
     Message,
+    ChatCompletionComplete,
     ChatCompletionChoice,
     ChatCompletionDelta,
     ChatCompletionChunk,
+    ChatCompletionLogprob,
+    ChatCompletionLogprobs,
+    ChatCompletionTopLogprob,
     ChatCompletion,
     CompletionRequest,
     CompletionComplete,
@@ -1098,6 +1103,7 @@ pub async fn run(
     StreamDetails,
     ErrorResponse,
     GrammarType,
+    Usage,
     )
     ),
tags( tags(

View File

@@ -1,7 +1,9 @@
/// Payload validation logic
use crate::validation::ValidationError::{BestOfSampling, BestOfSeed, EmptyInput};
use crate::{GenerateParameters, GenerateRequest, GrammarType};
+use jsonschema::{Draft, JSONSchema};
use rand::{thread_rng, Rng};
+use serde_json::Value;
use text_generation_client::{
    GrammarType as ProtoGrammarType, NextTokenChooserParameters, StoppingCriteriaParameters,
};
@@ -313,8 +315,29 @@ impl Validation {
            return Err(ValidationError::Grammar);
        }
        match grammar {
-           // currently both are handled the same way since compilation is done in Python
-           GrammarType::Json(json) => (json, ProtoGrammarType::Json.into()),
+           GrammarType::Json(json) => {
+               let json = match json {
+                   // if value is a string, we need to parse it again to make sure its
+                   // a valid json
+                   Value::String(s) => serde_json::from_str(&s)
+                       .map_err(|e| ValidationError::InvalidGrammar(e.to_string())),
+                   Value::Object(_) => Ok(json),
+                   _ => Err(ValidationError::Grammar),
+               }?;
+
+               // Check if the json is a valid JSONSchema
+               JSONSchema::options()
+                   .with_draft(Draft::Draft202012)
+                   .compile(&json)
+                   .map_err(|e| ValidationError::InvalidGrammar(e.to_string()))?;
+
+               (
+                   // Serialize json to string
+                   serde_json::to_string(&json)
+                       .map_err(|e| ValidationError::InvalidGrammar(e.to_string()))?,
+                   ProtoGrammarType::Json.into(),
+               )
+           }
            GrammarType::Regex(regex) => (regex, ProtoGrammarType::Regex.into()),
        }
    }
@@ -486,6 +509,8 @@ pub enum ValidationError {
    Tokenizer(String),
    #[error("grammar is not supported")]
    Grammar,
+   #[error("grammar is not valid: {0}")]
+   InvalidGrammar(String),
}

#[cfg(test)]
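To make the new check concrete, here is a rough Python analogue of the validation path added above, using the `jsonschema` package's Draft 2020-12 validator. It mirrors the string/object handling and the schema compilation step, but it is only a sketch of the behaviour, not the router's actual code.

# Rough Python analogue of the Rust validation above (assumes the `jsonschema` package is installed).
import json
from jsonschema import Draft202012Validator
from jsonschema.exceptions import SchemaError

def validate_json_grammar(value):
    # Mirror of the Value::String branch: a string payload is re-parsed as JSON first.
    if isinstance(value, str):
        value = json.loads(value)
    # Mirror of the fallthrough branch: anything that is not a JSON object is rejected.
    if not isinstance(value, dict):
        raise ValueError("grammar is not supported")
    # Mirror of JSONSchema::options().with_draft(Draft202012).compile(...): reject invalid schemas.
    try:
        Draft202012Validator.check_schema(value)
    except SchemaError as e:
        raise ValueError(f"grammar is not valid: {e}") from e
    # The validated schema is forwarded to the shard as a serialized string.
    return json.dumps(value)

print(validate_json_grammar('{"properties": {"location": {"type": "string"}}}'))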

View File

@@ -1,6 +1,6 @@
[tool.poetry]
name = "text-generation-server"
-version = "1.4.1"
+version = "1.4.2"
description = "Text Generation Inference Python gRPC Server"
authors = ["Olivier Dehaene <olivier@huggingface.co>"]

View File

@@ -52,6 +52,9 @@ try:
    from text_generation_server.models.flash_llama import (
        FlashLlama,
    )
+   from text_generation_server.models.flash_gemma import (
+       FlashGemma,
+   )
    from text_generation_server.models.flash_santacoder import (
        FlashSantacoderSharded,
    )
@@ -312,6 +315,28 @@ def get_model(
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
+   if model_type == "gemma":
+       if FLASH_ATTENTION:
+           return FlashGemma(
+               model_id,
+               revision,
+               quantize=quantize,
+               dtype=dtype,
+               trust_remote_code=trust_remote_code,
+               use_medusa=use_medusa,
+           )
+       elif sharded:
+           raise NotImplementedError(
+               FLASH_ATT_ERROR_MESSAGE.format("Sharded Golden Gate")
+           )
+       else:
+           return CausalLM(
+               model_id,
+               revision,
+               quantize=quantize,
+               dtype=dtype,
+               trust_remote_code=trust_remote_code,
+           )
    if model_type in ["RefinedWeb", "RefinedWebModel", "falcon"]:
        if sharded:

View File

@@ -0,0 +1,609 @@
# coding=utf-8
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.distributed
import os
from shutil import copyfile
from torch import nn
from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple
from tokenizers import processors
from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
from transformers.utils import logging
from text_generation_server.utils import paged_attention, flash_attn
from text_generation_server.utils.layers import (
TensorParallelRowLinear,
TensorParallelColumnLinear,
TensorParallelEmbedding,
PositionRotaryEmbedding,
TensorParallelHead,
get_linear,
FastRMSNorm,
)
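# No slow (sentencepiece) tokenizer implementation is bundled here, so the fast tokenizer's slow_tokenizer_class below is left as None.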
GemmaTokenizer = None
logger = logging.get_logger(__name__)
VOCAB_FILES_NAMES = {
"vocab_file": "tokenizer.model",
"tokenizer_file": "tokenizer.json",
}
PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": {
"hf-internal-testing/llama-tokenizer": "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer.model",
},
"tokenizer_file": {
"hf-internal-testing/llama-tokenizer": "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer_config.json",
},
}
B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
# fmt: off
DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your \
answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure\
that your responses are socially unbiased and positive in nature.
If a question does not make any sense, or is not factually coherent, explain why instead of answering something not \
correct. If you don't know the answer to a question, please don't share false information."""
# fmt: on
class GemmaTokenizerFast(PreTrainedTokenizerFast):
vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
slow_tokenizer_class = GemmaTokenizer
padding_side = "left"
model_input_names = ["input_ids", "attention_mask"]
def __init__(
self,
vocab_file=None,
tokenizer_file=None,
clean_up_tokenization_spaces=False,
unk_token="<unk>",
bos_token="<bos>",
eos_token="<eos>",
pad_token="<pad>",
add_bos_token=True,
add_eos_token=False,
use_default_system_prompt=False,
**kwargs,
):
super().__init__(
vocab_file=vocab_file,
tokenizer_file=tokenizer_file,
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
unk_token=unk_token,
bos_token=bos_token,
eos_token=eos_token,
pad_token=pad_token,
add_bos_token=add_bos_token,
add_eos_token=add_eos_token,
use_default_system_prompt=use_default_system_prompt,
**kwargs,
)
self._add_bos_token = add_bos_token
self._add_eos_token = add_eos_token
self.update_post_processor()
self.use_default_system_prompt = use_default_system_prompt
self.vocab_file = vocab_file
@property
def can_save_slow_tokenizer(self) -> bool:
return os.path.isfile(self.vocab_file) if self.vocab_file else False
def update_post_processor(self):
"""
Updates the underlying post processor with the current `bos_token` and `eos_token`.
"""
bos = self.bos_token
bos_token_id = self.bos_token_id
if bos is None and self.add_bos_token:
raise ValueError("add_bos_token = True but bos_token = None")
eos = self.eos_token
eos_token_id = self.eos_token_id
if eos is None and self.add_eos_token:
raise ValueError("add_eos_token = True but eos_token = None")
single = f"{(bos+':0 ') if self.add_bos_token else ''}$A:0{(' '+eos+':0') if self.add_eos_token else ''}"
pair = f"{single}{(' '+bos+':1') if self.add_bos_token else ''} $B:1{(' '+eos+':1') if self.add_eos_token else ''}"
special_tokens = []
if self.add_bos_token:
special_tokens.append((bos, bos_token_id))
if self.add_eos_token:
special_tokens.append((eos, eos_token_id))
self._tokenizer.post_processor = processors.TemplateProcessing(
single=single, pair=pair, special_tokens=special_tokens
)
@property
def add_eos_token(self):
return self._add_eos_token
@property
def add_bos_token(self):
return self._add_bos_token
@add_eos_token.setter
def add_eos_token(self, value):
self._add_eos_token = value
self.update_post_processor()
@add_bos_token.setter
def add_bos_token(self, value):
self._add_bos_token = value
self.update_post_processor()
def save_vocabulary(
self, save_directory: str, filename_prefix: Optional[str] = None
) -> Tuple[str]:
if not self.can_save_slow_tokenizer:
raise ValueError(
"Your fast tokenizer does not have the necessary information to save the vocabulary for a slow "
"tokenizer."
)
if not os.path.isdir(save_directory):
logger.error(f"Vocabulary path ({save_directory}) should be a directory")
return
out_vocab_file = os.path.join(
save_directory,
(filename_prefix + "-" if filename_prefix else "")
+ VOCAB_FILES_NAMES["vocab_file"],
)
if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
copyfile(self.vocab_file, out_vocab_file)
return (out_vocab_file,)
@property
def default_chat_template(self):
raise NotImplementedError
# TODO ArthurZ let's rely on the template processor instead, refactor all fast tokenizers
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
bos_token_id = [self.bos_token_id] if self.add_bos_token else []
eos_token_id = [self.eos_token_id] if self.add_eos_token else []
output = bos_token_id + token_ids_0 + eos_token_id
if token_ids_1 is not None:
output = output + bos_token_id + token_ids_1 + eos_token_id
return output
class GemmaConfig(PretrainedConfig):
def __init__(
self,
vocab_size=256128,
hidden_size=3072,
intermediate_size=24576,
num_hidden_layers=28,
num_attention_heads=16,
num_key_value_heads=16,
head_dim=256,
hidden_act="gelu",
max_position_embeddings=8192,
initializer_range=0.02,
rms_norm_eps=1e-6,
use_cache=True,
pad_token_id=None,
bos_token_id=1,
eos_token_id=2,
tie_word_embeddings=True,
rope_theta=10000.0,
rope_scaling=None,
attention_bias=False,
attention_dropout=0.0,
**kwargs,
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.head_dim = head_dim
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
# for backward compatibility
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
class GemmaFastRMSNorm(FastRMSNorm):
@classmethod
def load(cls, prefix, weights, eps=1e-6):
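        # Gemma checkpoints store the RMSNorm scale offset by 1 (the norm applies x * (1 + w)), so add 1 back to recover the full multiplicative weight.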
weight = weights.get_tensor(f"{prefix}.weight") + 1
return cls(weight, eps)
def load_attention(config, prefix, weights):
if config.num_attention_heads != config.num_key_value_heads:
return _load_gqa(config, prefix, weights)
else:
return TensorParallelColumnLinear.load_multi(
config,
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
dim=0,
weights=weights,
bias=False,
)
def _load_gqa(config, prefix: str, weights):
assert config.num_attention_heads % weights.process_group.size() == 0
weight = weights.get_multi_weights_col(
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
quantize=config.quantize,
dim=0,
)
if config.quantize not in ["gptq", "awq"]:
weight = weight.to(dtype=weights.dtype).to(device=weights.device)
head_size = config.head_dim
num_heads = config.num_attention_heads // weights.process_group.size()
num_key_value_heads = config.num_key_value_heads // weights.process_group.size()
assert list(weight.shape) == [
(num_heads + 2 * num_key_value_heads) * head_size,
config.hidden_size,
], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
return TensorParallelColumnLinear(
get_linear(weight, bias=None, quantize=config.quantize)
)
class FlashGemmaAttention(torch.nn.Module):
def __init__(
self,
prefix: str,
config,
weights,
):
super().__init__()
self.num_heads = config.num_attention_heads
self.head_size = config.head_dim
self.rotary_emb = PositionRotaryEmbedding.static(
config=config,
dim=self.head_size,
base=config.rope_theta,
device=weights.device,
)
self.softmax_scale = self.head_size**-0.5
if self.num_heads % weights.process_group.size() != 0:
raise ValueError(
f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
f"and `num_shards`: {weights.process_group.size()}"
)
self.num_heads = self.num_heads // weights.process_group.size()
self.num_key_value_heads = (
config.num_key_value_heads // weights.process_group.size()
)
self.query_key_value = load_attention(config, prefix, weights)
self.o_proj = TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.o_proj",
weights=weights,
bias=False,
)
self.num_groups = self.num_heads // self.num_key_value_heads
self.kv_head_mapping = torch.arange(
0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
).repeat_interleave(self.num_groups)
def forward(
self,
hidden_states,
cos,
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
):
qkv = self.query_key_value(hidden_states)
query, kv = qkv.split(
[
self.head_size * self.num_heads,
2 * self.head_size * self.num_key_value_heads,
],
dim=1,
)
query = query.view(-1, self.num_heads, self.head_size)
kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
paged_attention.reshape_and_cache(
kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
)
# output tensor
attn_output = torch.empty_like(query)
# Prefill
if cu_seqlen_prefill is not None:
# flash attention
flash_attn.attention(
query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
attn_output,
cu_seqlen_prefill,
max_s,
self.softmax_scale,
)
# Decode
else:
paged_attention.attention(
attn_output,
query,
kv_cache[0],
kv_cache[1],
self.kv_head_mapping,
self.softmax_scale,
block_tables,
input_lengths,
max_s,
)
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
class GemmaMLP(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
act = config.hidden_act
self.act = (
ACT2FN[act]
if "gelu" not in act
else lambda x: torch.nn.functional.gelu(
x,
approximate=(
"tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
),
)
)
# Fuse gate and up proj
self.gate_up_proj = TensorParallelColumnLinear.load_multi(
config,
prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
weights=weights,
dim=0,
bias=False,
)
self.down_proj = TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.down_proj",
weights=weights,
bias=False,
)
self.intermediate_size = (
config.intermediate_size // weights.process_group.size()
)
def forward(self, hidden_states):
gate_up_states = self.gate_up_proj(hidden_states)
gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1])
class FlashGemmaLayer(nn.Module):
def __init__(self, layer_id, config, weights):
super().__init__()
prefix = f"model.layers.{layer_id}"
self.self_attn = FlashGemmaAttention(
prefix=f"{prefix}.self_attn", config=config, weights=weights
)
self.mlp = GemmaMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
self.input_layernorm = GemmaFastRMSNorm.load(
prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
)
self.post_attention_layernorm = GemmaFastRMSNorm.load(
prefix=f"{prefix}.post_attention_layernorm",
weights=weights,
eps=config.rms_norm_eps,
)
def forward(
self,
hidden_states,
residual,
cos,
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
):
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
# Self Attention
attn_output = self.self_attn(
normed_hidden_states,
cos,
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
)
# faster post attention rms norm
normed_attn_res_output, attn_res = self.post_attention_layernorm(
attn_output, res
)
mlp_output = self.mlp(normed_attn_res_output)
return mlp_output, attn_res
class FlashGemmaModel(torch.nn.Module):
def __init__(self, config, weights):
super().__init__()
process_group = weights.process_group
self.tp_rank = process_group.rank()
self.tp_world_size = process_group.size()
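        # Gemma scales input embeddings by sqrt(hidden_size); fold that factor into the embedding weights once at load time.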
embed_norm = config.hidden_size**0.5
self.embed_tokens = TensorParallelEmbedding(
prefix="model.embed_tokens", weights=weights
)
self.embed_tokens.weight *= embed_norm
self.layers = nn.ModuleList(
[
FlashGemmaLayer(
layer_id,
config,
weights,
)
for layer_id in range(config.num_hidden_layers)
]
)
self.norm = GemmaFastRMSNorm.load(
prefix="model.norm", weights=weights, eps=config.rms_norm_eps
)
self.gradient_checkpointing = False
self.head_size = self.layers[0].self_attn.head_size
self.num_heads = self.layers[0].self_attn.num_heads
self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
max_s: int,
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
# Get rotary cos and sin for this forward
# Avoid to index in each layer
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
position_ids, max_s, hidden_states.dtype
)
residual = None
for i, layer in enumerate(self.layers):
hidden_states, residual = layer(
hidden_states,
residual,
cos,
sin,
cu_seqlen_prefill,
kv_cache[i],
block_tables,
slots,
input_lengths,
max_s,
)
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
class FlashGemmaForCausalLM(torch.nn.Module):
def __init__(self, config, weights):
super().__init__()
self.model = FlashGemmaModel(config, weights)
self.lm_head = TensorParallelHead.load(
config,
prefix="model.embed_tokens" if config.tie_word_embeddings else "lm_head",
weights=weights,
)
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
max_s: int,
lm_head_indices: Optional[torch.Tensor] = None,
) -> torch.Tensor:
hidden_states = self.model(
input_ids,
position_ids,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]
logits = self.lm_head(hidden_states)
return logits

View File

@@ -0,0 +1,104 @@
import torch
import torch.distributed
from opentelemetry import trace
from typing import Optional
from text_generation_server.models import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_gemma_modeling import (
GemmaTokenizerFast,
FlashGemmaForCausalLM,
GemmaConfig,
)
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
)
tracer = trace.get_tracer(__name__)
class FlashGemma(FlashCausalLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
use_medusa: Optional[str] = None,
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.bfloat16 if dtype is None else dtype
else:
raise NotImplementedError("FlashGemma is only available on GPU")
tokenizer = GemmaTokenizerFast.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
use_fast=True,
from_slow=False,
)
config = GemmaConfig.from_pretrained(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
config.quantize = quantize
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(filenames, device, dtype, process_group=self.process_group)
if config.quantize in ["gptq", "awq"]:
weights._set_gptq_params(model_id, revision)
model = FlashGemmaForCausalLM(config, weights)
if use_medusa:
from text_generation_server.utils.medusa import MedusaModel
from huggingface_hub import hf_hub_download
import json
import os
from pathlib import Path
is_local_model = (
Path(use_medusa).exists() and Path(use_medusa).is_dir()
) or os.getenv("WEIGHTS_CACHE_OVERRIDE", None) is not None
if not is_local_model:
medusa_config = hf_hub_download(
use_medusa, revision=revision, filename="config.json"
)
medusa_head = hf_hub_download(
use_medusa, revision=revision, filename="medusa_lm_head.pt"
)
else:
medusa_config = str(Path(use_medusa) / "config.json")
medusa_head = str(Path(use_medusa) / "medusa_lm_head.pt")
with open(medusa_config, "r") as f:
config = json.load(f)
medusa_sf = medusa_head[: -len(".pt")] + ".safetensors"
weights = Weights(
[medusa_sf], device, dtype, process_group=self.process_group
)
lm_head = model.lm_head
model.lm_head = MedusaModel(config, weights, lm_head)
torch.distributed.barrier(group=self.process_group)
super(FlashGemma, self).__init__(
model=model,
tokenizer=tokenizer,
num_layers=len(model.model.layers),
num_kv_heads=model.model.num_key_value_heads,
head_size=model.model.head_size,
dtype=dtype,
device=device,
rank=rank,
world_size=world_size,
)

View File

@@ -328,7 +328,6 @@ class HeterogeneousNextTokenChooser:
            scores = scores.view(B, S, -1)

        next_ids = torch.zeros((B, S), device=scores.device, dtype=torch.long)
-       mask = torch.full((scores.shape[-1],), -math.inf, device=self.device)

        for j in range(S):
            _scores = scores[:, j]
@@ -338,10 +337,10 @@ class HeterogeneousNextTokenChooser:
                _scores = self.repetition_processor(input_ids, _scores)
            if self.frequency_processor is not None:
                _scores = self.frequency_processor(input_ids, _scores)
-           for warper in self.warpers:
-               _scores = warper(input_ids, _scores)
            if self.grammar_processor is not None:
                _scores = self.grammar_processor(_scores, self.fsm_grammar_states)
+           for warper in self.warpers:
+               _scores = warper(input_ids, _scores)
            _next_ids = self.choice(_scores)
            scores[:, j] = _scores
            next_ids[:, j] = _next_ids