Merge branch 'main' into cpu_perf

Author: Wang, Yi, 2024-10-15 21:50:15 +08:00 (committed by GitHub)
Commit: fb4d2080af
122 changed files with 8272 additions and 2261 deletions


@ -4,3 +4,4 @@ server/transformers
server/flash-attention
cmake-build-debug/
cmake-build-release/
Dockerfile*


@ -21,9 +21,11 @@ jobs:
build-and-push:
outputs:
docker_image: ${{ steps.final.outputs.docker_image }}
docker_volume: ${{ steps.final.outputs.docker_volume }}
docker_devices: ${{ steps.final.outputs.docker_devices }}
runs_on: ${{ steps.final.outputs.runs_on }}
label: ${{ steps.final.outputs.label }}
extra_pytest: ${{ steps.final.outputs.extra_pytest }}
concurrency:
group: ${{ github.workflow }}-build-and-push-image-${{ inputs.hardware }}-${{ github.head_ref || github.run_id }}
cancel-in-progress: true
@ -44,32 +46,39 @@ jobs:
cuda)
export dockerfile="Dockerfile"
export label_extension=""
export docker_volume="/mnt/cache"
export docker_devices=""
export runs_on="aws-g6-12xl-plus-priv-cache"
export platform=""
export extra_pytest=""
;;
rocm)
export dockerfile="Dockerfile_amd"
export label_extension="-rocm"
export docker_devices="/dev/kfd,/dev/dri"
# TODO Re-enable when they pass.
# export runs_on="amd-gpu-tgi"
export runs_on="ubuntu-latest"
export docker_volume="/mnt"
export runs_on="amd-gpu-runners"
export platform=""
export extra_pytest="-k test_flash_gemma_gptq_load"
;;
intel-xpu)
export dockerfile="Dockerfile_intel"
export label_extension="-intel-xpu"
export docker_devices=""
export docker_volume="/mnt/cache"
export runs_on="ubuntu-latest"
export platform="xpu"
export extra_pytest=""
;;
intel-cpu)
export dockerfile="Dockerfile_intel"
export label_extension="-intel-cpu"
export docker_devices=""
export runs_on="ubuntu-latest"
export docker_devices="none"
export docker_volume="/mnt/cache"
# export runs_on="ubuntu-latest"
export runs_on="aws-highmemory-32-plus-priv"
export platform="cpu"
export extra_pytest="-k test_flash_gemma_simple"
;;
esac
echo $dockerfile
@ -81,8 +90,10 @@ jobs:
echo "DOCKERFILE=${dockerfile}" >> $GITHUB_ENV
echo "LABEL=${label_extension}" >> $GITHUB_ENV
echo "PLATFORM=${platform}" >> $GITHUB_ENV
echo "DOCKER_VOLUME=${docker_volume}" >> $GITHUB_ENV
echo "DOCKER_DEVICES=${docker_devices}" >> $GITHUB_ENV
echo "RUNS_ON=${runs_on}" >> $GITHUB_ENV
echo "EXTRA_PYTEST=${extra_pytest}" >> $GITHUB_ENV
echo REGISTRY_MIRROR=$REGISTRY_MIRROR >> $GITHUB_ENV
- name: Initialize Docker Buildx
uses: docker/setup-buildx-action@v3
@ -157,16 +168,18 @@ jobs:
run: |
echo "docker_image=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:sha-${{ env.GITHUB_SHA_SHORT}}${{ env.LABEL }}" >> "$GITHUB_OUTPUT"
echo "docker_devices=${{ env.DOCKER_DEVICES }}" >> "$GITHUB_OUTPUT"
echo "docker_volume=${{ env.DOCKER_VOLUME }}" >> "$GITHUB_OUTPUT"
echo "runs_on=${{ env.RUNS_ON }}" >> "$GITHUB_OUTPUT"
echo "label=${{ env.LABEL }}" >> "$GITHUB_OUTPUT"
echo "extra_pytest=${{ env.EXTRA_PYTEST }}" >> "$GITHUB_OUTPUT"
integration_tests:
concurrency:
group: ${{ github.workflow }}-${{ github.job }}-${{ needs.build-and-push.outputs.label }}-${{ github.head_ref || github.run_id }}
cancel-in-progress: true
needs: build-and-push
if: needs.build-and-push.outputs.runs_on != 'ubuntu-latest'
runs-on:
group: ${{ needs.build-and-push.outputs.runs_on }}
if: needs.build-and-push.outputs.runs_on != 'ubuntu-latest'
env:
PYTEST_FLAGS: ${{ (startsWith(github.ref, 'refs/tags/') || github.ref == 'refs/heads/main' || inputs.release-tests == true) && '--release' || '--release' }}
steps:
@ -177,15 +190,16 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: "3.10"
python-version: "3.11"
- name: Install
run: |
make install-integration-tests
- name: Run tests
run: |
export DOCKER_VOLUME=/mnt/cache
export DOCKER_VOLUME=${{ needs.build-and-push.outputs.docker_volume }}
export DOCKER_IMAGE=${{ needs.build-and-push.outputs.docker_image }}
export DOCKER_DEVICES=${{ needs.build-and-push.outputs.docker_devices }}
export EXTRA_PYTEST="${{ needs.build-and-push.outputs.extra_pytest }}"
export HF_TOKEN=${{ secrets.HF_TOKEN }}
echo $DOCKER_IMAGE
pytest -s -vv integration-tests ${PYTEST_FLAGS}
pytest -s -vv integration-tests ${PYTEST_FLAGS} ${EXTRA_PYTEST}

Cargo.lock (generated)

@ -133,7 +133,7 @@ checksum = "0ae92a5119aa49cdbcf6b9f893fe4e1d98b04ccbf82ee0584ad948a44a734dea"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.77",
"syn 2.0.79",
]
[[package]]
@ -172,7 +172,7 @@ checksum = "16e62a023e7c117e27523144c5d2459f4397fcc3cab0085af8e2224f643a0193"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.77",
"syn 2.0.79",
]
[[package]]
@ -183,7 +183,7 @@ checksum = "721cae7de5c34fbb2acd27e21e6d2cf7b886dce0c27388d46c4e6c47ea4318dd"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.77",
"syn 2.0.79",
]
[[package]]
@ -205,9 +205,9 @@ dependencies = [
[[package]]
name = "autocfg"
version = "1.3.0"
version = "1.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0c4b4d0bd25bd0b74681c0ad21497610ce1b7c91b1022cd21c80c6fbdd9476b0"
checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26"
[[package]]
name = "av1-grain"
@ -316,12 +316,12 @@ dependencies = [
[[package]]
name = "axum"
version = "0.7.6"
version = "0.7.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8f43644eed690f5374f1af436ecd6aea01cd201f6fbdf0178adaf6907afb2cec"
checksum = "504e3947307ac8326a5437504c517c4b56716c9d98fac0028c2acc7ca47d70ae"
dependencies = [
"async-trait",
"axum-core 0.4.4",
"axum-core 0.4.5",
"bytes",
"futures-util",
"http 1.1.0",
@ -367,9 +367,9 @@ dependencies = [
[[package]]
name = "axum-core"
version = "0.4.4"
version = "0.4.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5e6b8ba012a258d63c9adfa28b9ddcf66149da6f986c5b5452e629d5ee64bf00"
checksum = "09f2bd6146b97ae3359fa0cc6d6b376d9539582c7b4220f041a33ec24c226199"
dependencies = [
"async-trait",
"bytes",
@ -392,7 +392,7 @@ version = "0.16.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bdad298231394729042d1f155b93f9fdf0b5ee1aea0b62404c4d7341f7d8fe08"
dependencies = [
"axum 0.7.6",
"axum 0.7.7",
"futures-core",
"futures-util",
"http 1.1.0",
@ -456,7 +456,7 @@ dependencies = [
"regex",
"rustc-hash",
"shlex",
"syn 2.0.77",
"syn 2.0.79",
"which",
]
@ -605,9 +605,9 @@ dependencies = [
[[package]]
name = "cc"
version = "1.1.21"
version = "1.1.22"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "07b1695e2c7e8fc85310cde85aeaab7e3097f593c91d209d3f9df76c928100f0"
checksum = "9540e661f81799159abee814118cc139a2004b3a3aa3ea37724a1b66530b90e0"
dependencies = [
"jobserver",
"libc",
@ -704,7 +704,7 @@ dependencies = [
"heck 0.5.0",
"proc-macro2",
"quote",
"syn 2.0.77",
"syn 2.0.79",
]
[[package]]
@ -971,7 +971,7 @@ dependencies = [
"proc-macro2",
"quote",
"scratch",
"syn 2.0.77",
"syn 2.0.79",
]
[[package]]
@ -988,7 +988,7 @@ checksum = "98532a60dedaebc4848cb2cba5023337cc9ea3af16a5b062633fabfd9f18fb60"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.77",
"syn 2.0.79",
]
[[package]]
@ -1012,7 +1012,7 @@ dependencies = [
"proc-macro2",
"quote",
"strsim",
"syn 2.0.77",
"syn 2.0.79",
]
[[package]]
@ -1023,7 +1023,7 @@ checksum = "d336a2a514f6ccccaa3e09b02d41d35330c07ddf03a62165fcec10bb561c7806"
dependencies = [
"darling_core",
"quote",
"syn 2.0.77",
"syn 2.0.79",
]
[[package]]
@ -1053,7 +1053,7 @@ dependencies = [
"darling",
"proc-macro2",
"quote",
"syn 2.0.77",
"syn 2.0.79",
]
[[package]]
@ -1063,7 +1063,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4abae7035bf79b9877b779505d8cf3749285b80c43941eda66604841889451dc"
dependencies = [
"derive_builder_core",
"syn 2.0.77",
"syn 2.0.79",
]
[[package]]
@ -1192,9 +1192,9 @@ checksum = "e8c02a5121d4ea3eb16a80748c74f5549a5665e4c21333c6098f283870fbdea6"
[[package]]
name = "fdeflate"
version = "0.3.4"
version = "0.3.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4f9bfee30e4dedf0ab8b422f03af778d9612b63f502710fc500a334ebe2de645"
checksum = "d8090f921a24b04994d9929e204f50b498a33ea6ba559ffaa05e04f7ee7fb5ab"
dependencies = [
"simd-adler32",
]
@ -1207,9 +1207,9 @@ checksum = "0ce7134b9999ecaf8bcd65542e436736ef32ddca1b3e06094cb6ec5755203b80"
[[package]]
name = "flate2"
version = "1.0.33"
version = "1.0.34"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "324a1be68054ef05ad64b861cc9eaf1d623d2d8cb25b4bf2cb9cdd902b4bf253"
checksum = "a1b589b4dc103969ad3cf85c950899926ec64300a1a46d76c03a6072957036f0"
dependencies = [
"crc32fast",
"miniz_oxide 0.8.0",
@ -1338,7 +1338,7 @@ checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.77",
"syn 2.0.79",
]
[[package]]
@ -1864,7 +1864,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b23a0c8dfe501baac4adf6ebbfa6eddf8f0c07f56b058cc1288017e32397846c"
dependencies = [
"quote",
"syn 2.0.77",
"syn 2.0.79",
]
[[package]]
@ -1884,7 +1884,7 @@ checksum = "c34819042dc3d3971c46c2190835914dfbe0c3c13f61449b2997f4e9722dfa60"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.77",
"syn 2.0.79",
]
[[package]]
@ -2270,7 +2270,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b8a240ddb74feaf34a79a7add65a741f3167852fba007066dcac1ca548d89c08"
dependencies = [
"adler",
"simd-adler32",
]
[[package]]
@ -2280,6 +2279,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e2d80299ef12ff69b16a84bb182e3b9df68b5a91574d3d4fa6e41b65deec4df1"
dependencies = [
"adler2",
"simd-adler32",
]
[[package]]
@ -2319,7 +2319,7 @@ checksum = "a7ce64b975ed4f123575d11afd9491f2e37bbd5813fbfbc0f09ae1fbddea74e0"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.77",
"syn 2.0.79",
]
[[package]]
@ -2519,7 +2519,7 @@ checksum = "ed3955f1a9c7c0c15e092f9c887db08b1fc683305fdf6eb6684f22555355e202"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.77",
"syn 2.0.79",
]
[[package]]
@ -2599,9 +2599,12 @@ dependencies = [
[[package]]
name = "once_cell"
version = "1.19.0"
version = "1.20.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92"
checksum = "82881c4be219ab5faaf2ad5e5e5ecdff8c66bd7402ca3160975c93b24961afd1"
dependencies = [
"portable-atomic",
]
[[package]]
name = "onig"
@ -2654,7 +2657,7 @@ checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.77",
"syn 2.0.79",
]
[[package]]
@ -2808,7 +2811,7 @@ dependencies = [
"glob",
"once_cell",
"opentelemetry 0.21.0",
"ordered-float 4.2.2",
"ordered-float 4.3.0",
"percent-encoding",
"rand",
"thiserror",
@ -2828,7 +2831,7 @@ dependencies = [
"lazy_static",
"once_cell",
"opentelemetry 0.23.0",
"ordered-float 4.2.2",
"ordered-float 4.3.0",
"percent-encoding",
"rand",
"thiserror",
@ -2851,9 +2854,9 @@ dependencies = [
[[package]]
name = "ordered-float"
version = "4.2.2"
version = "4.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4a91171844676f8c7990ce64959210cd2eaef32c2612c50f9fae9f8aaa6065a6"
checksum = "44d501f1a72f71d3c063a6bbc8f7271fa73aa09fe5d6283b6571e2ed176a2537"
dependencies = [
"num-traits",
]
@ -2937,7 +2940,7 @@ checksum = "2f38a4412a78282e09a2cf38d195ea5420d15ba0602cb375210efbc877243965"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.77",
"syn 2.0.79",
]
[[package]]
@ -2988,22 +2991,22 @@ dependencies = [
[[package]]
name = "png"
version = "0.17.13"
version = "0.17.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "06e4b0d3d1312775e782c86c91a111aa1f910cbb65e1337f9975b5f9a554b5e1"
checksum = "52f9d46a34a05a6a57566bc2bfae066ef07585a6e3fa30fbbdff5936380623f0"
dependencies = [
"bitflags 1.3.2",
"crc32fast",
"fdeflate",
"flate2",
"miniz_oxide 0.7.4",
"miniz_oxide 0.8.0",
]
[[package]]
name = "portable-atomic"
version = "1.8.0"
version = "1.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d30538d42559de6b034bc76fd6dd4c38961b1ee5c6c56e3808c50128fdbc22ce"
checksum = "cc9c68a3f6da06753e9335d63e27f6b9754dd1920d941135b7ea8224f141adb2"
[[package]]
name = "powerfmt"
@ -3027,7 +3030,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "479cf940fbbb3426c32c5d5176f62ad57549a0bb84773423ba8be9d089f5faba"
dependencies = [
"proc-macro2",
"syn 2.0.77",
"syn 2.0.79",
]
[[package]]
@ -3079,7 +3082,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8021cf59c8ec9c432cfc2526ac6b8aa508ecaf29cd415f271b8406c1b851c3fd"
dependencies = [
"quote",
"syn 2.0.77",
"syn 2.0.79",
]
[[package]]
@ -3119,7 +3122,7 @@ dependencies = [
"prost 0.12.6",
"prost-types",
"regex",
"syn 2.0.77",
"syn 2.0.79",
"tempfile",
]
@ -3146,7 +3149,7 @@ dependencies = [
"itertools 0.12.1",
"proc-macro2",
"quote",
"syn 2.0.77",
"syn 2.0.79",
]
[[package]]
@ -3205,7 +3208,7 @@ dependencies = [
"proc-macro2",
"pyo3-macros-backend",
"quote",
"syn 2.0.77",
"syn 2.0.79",
]
[[package]]
@ -3218,7 +3221,7 @@ dependencies = [
"proc-macro2",
"pyo3-build-config",
"quote",
"syn 2.0.77",
"syn 2.0.79",
]
[[package]]
@ -3402,9 +3405,9 @@ dependencies = [
[[package]]
name = "redox_syscall"
version = "0.5.5"
version = "0.5.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "62871f2d65009c0256aed1b9cfeeb8ac272833c404e13d53d400cd0dad7a2ac0"
checksum = "9b6dfecf2c74bce2466cabf93f6664d6998a69eb21e39f4207930065b27b771f"
dependencies = [
"bitflags 2.6.0",
]
@ -3422,14 +3425,14 @@ dependencies = [
[[package]]
name = "regex"
version = "1.10.6"
version = "1.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4219d74c6b67a3654a9fbebc4b419e22126d13d2f3c4a07ee0cb61ff79a79619"
checksum = "38200e5ee88914975b69f657f0801b6f6dccafd44fd9326302a4aaeecfacb1d8"
dependencies = [
"aho-corasick",
"memchr",
"regex-automata 0.4.7",
"regex-syntax 0.8.4",
"regex-automata 0.4.8",
"regex-syntax 0.8.5",
]
[[package]]
@ -3443,13 +3446,13 @@ dependencies = [
[[package]]
name = "regex-automata"
version = "0.4.7"
version = "0.4.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "38caf58cc5ef2fed281f89292ef23f6365465ed9a41b7a7754eb4e26496c92df"
checksum = "368758f23274712b504848e9d5a6f010445cc8b87a7cdb4d7cbee666c1288da3"
dependencies = [
"aho-corasick",
"memchr",
"regex-syntax 0.8.4",
"regex-syntax 0.8.5",
]
[[package]]
@ -3460,9 +3463,9 @@ checksum = "f162c6dd7b008981e4d40210aca20b4bd0f9b60ca9271061b07f78537722f2e1"
[[package]]
name = "regex-syntax"
version = "0.8.4"
version = "0.8.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7a66a03ae7c801facd77a29370b4faec201768915ac14a721ba36f20bc9c209b"
checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c"
[[package]]
name = "reqwest"
@ -3563,7 +3566,7 @@ dependencies = [
"proc-macro2",
"quote",
"rust-embed-utils",
"syn 2.0.77",
"syn 2.0.79",
"walkdir",
]
@ -3686,9 +3689,9 @@ dependencies = [
[[package]]
name = "rustls-pki-types"
version = "1.8.0"
version = "1.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fc0a2ce646f8655401bb81e7927b812614bd5d91dbc968696be50603510fcaf0"
checksum = "0e696e35370c65c9c541198af4543ccd580cf17fc25d8e05c5a242b202488c55"
[[package]]
name = "rustls-webpki"
@ -3813,7 +3816,7 @@ checksum = "243902eda00fad750862fc144cea25caca5e20d615af0a81bee94ca738f1df1f"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.77",
"syn 2.0.79",
]
[[package]]
@ -3840,9 +3843,9 @@ dependencies = [
[[package]]
name = "serde_spanned"
version = "0.6.7"
version = "0.6.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "eb5b1b31579f3811bf615c144393417496f152e12ac8b7663bf664f4a815306d"
checksum = "87607cb1398ed59d48732e575a4c28a7a8ebf2454b964fe3f224f2afc07909e1"
dependencies = [
"serde",
]
@ -4028,7 +4031,7 @@ dependencies = [
"proc-macro2",
"quote",
"rustversion",
"syn 2.0.77",
"syn 2.0.79",
]
[[package]]
@ -4050,9 +4053,9 @@ dependencies = [
[[package]]
name = "syn"
version = "2.0.77"
version = "2.0.79"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9f35bcdf61fd8e7be6caf75f429fdca8beb3ed76584befb503b1569faee373ed"
checksum = "89132cd0bf050864e1d38dc3bbc07a0eb8e7530af26344d3d2bbbef83499f590"
dependencies = [
"proc-macro2",
"quote",
@ -4152,9 +4155,9 @@ checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1"
[[package]]
name = "tempfile"
version = "3.12.0"
version = "3.13.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "04cbcdd0c794ebb0d4cf35e88edd2f7d2c4c3e9a5a6dab322839b321c6a87a64"
checksum = "f0f2c9fc62d0beef6951ccffd757e241266a2c833136efbe35af6cd2567dca5b"
dependencies = [
"cfg-if",
"fastrand",
@ -4174,7 +4177,7 @@ dependencies = [
[[package]]
name = "text-generation-backends-trtllm"
version = "2.3.1-dev0"
version = "2.3.2-dev0"
dependencies = [
"async-stream",
"async-trait",
@ -4197,7 +4200,7 @@ dependencies = [
[[package]]
name = "text-generation-benchmark"
version = "2.3.1-dev0"
version = "2.3.2-dev0"
dependencies = [
"average",
"clap 4.5.18",
@ -4217,7 +4220,7 @@ dependencies = [
[[package]]
name = "text-generation-client"
version = "2.3.1-dev0"
version = "2.3.2-dev0"
dependencies = [
"async-trait",
"base64 0.22.1",
@ -4235,7 +4238,7 @@ dependencies = [
[[package]]
name = "text-generation-launcher"
version = "2.3.1-dev0"
version = "2.3.2-dev0"
dependencies = [
"clap 4.5.18",
"ctrlc",
@ -4244,6 +4247,7 @@ dependencies = [
"nix 0.28.0",
"once_cell",
"pyo3",
"regex",
"reqwest",
"serde",
"serde_json",
@ -4255,11 +4259,11 @@ dependencies = [
[[package]]
name = "text-generation-router"
version = "2.3.1-dev0"
version = "2.3.2-dev0"
dependencies = [
"async-stream",
"async-trait",
"axum 0.7.6",
"axum 0.7.7",
"axum-tracing-opentelemetry",
"base64 0.22.1",
"clap 4.5.18",
@ -4304,11 +4308,11 @@ dependencies = [
[[package]]
name = "text-generation-router-v2"
version = "2.3.1-dev0"
version = "2.3.2-dev0"
dependencies = [
"async-stream",
"async-trait",
"axum 0.7.6",
"axum 0.7.7",
"axum-tracing-opentelemetry",
"base64 0.22.1",
"clap 4.5.18",
@ -4353,11 +4357,11 @@ dependencies = [
[[package]]
name = "text-generation-router-v3"
version = "2.3.1-dev0"
version = "2.3.2-dev0"
dependencies = [
"async-stream",
"async-trait",
"axum 0.7.6",
"axum 0.7.7",
"axum-tracing-opentelemetry",
"base64 0.22.1",
"clap 4.5.18",
@ -4428,7 +4432,7 @@ checksum = "08904e7672f5eb876eaaf87e0ce17857500934f4981c4a0ab2b4aa98baac7fc3"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.77",
"syn 2.0.79",
]
[[package]]
@ -4533,7 +4537,7 @@ dependencies = [
"rayon",
"rayon-cond",
"regex",
"regex-syntax 0.8.4",
"regex-syntax 0.8.5",
"serde",
"serde_json",
"spm_precompiled",
@ -4566,7 +4570,7 @@ dependencies = [
"rayon",
"rayon-cond",
"regex",
"regex-syntax 0.8.4",
"regex-syntax 0.8.5",
"serde",
"serde_json",
"spm_precompiled",
@ -4612,7 +4616,7 @@ checksum = "693d596312e88961bc67d7f1f97af8a70227d9f90c31bba5806eec004978d752"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.77",
"syn 2.0.79",
]
[[package]]
@ -4771,7 +4775,7 @@ dependencies = [
"proc-macro2",
"prost-build",
"quote",
"syn 2.0.77",
"syn 2.0.79",
]
[[package]]
@ -4858,7 +4862,7 @@ checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.77",
"syn 2.0.79",
]
[[package]]
@ -5151,7 +5155,7 @@ dependencies = [
"proc-macro2",
"quote",
"regex",
"syn 2.0.77",
"syn 2.0.79",
]
[[package]]
@ -5160,7 +5164,7 @@ version = "6.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0b39868d43c011961e04b41623e050aedf2cc93652562ff7935ce0f819aaf2da"
dependencies = [
"axum 0.7.6",
"axum 0.7.7",
"mime_guess",
"regex",
"rust-embed",
@ -5189,7 +5193,7 @@ checksum = "ee1cd046f83ea2c4e920d6ee9f7c3537ef928d75dce5d84a87c2c5d6b3999a3a"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.77",
"syn 2.0.79",
]
[[package]]
@ -5290,7 +5294,7 @@ dependencies = [
"once_cell",
"proc-macro2",
"quote",
"syn 2.0.77",
"syn 2.0.79",
"wasm-bindgen-shared",
]
@ -5324,7 +5328,7 @@ checksum = "afc340c74d9005395cf9dd098506f7f44e38f2b4a21c6aaacf9a105ea5e1e836"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.77",
"syn 2.0.79",
"wasm-bindgen-backend",
"wasm-bindgen-shared",
]
@ -5668,9 +5672,9 @@ checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec"
[[package]]
name = "winnow"
version = "0.6.19"
version = "0.6.20"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c52ac009d615e79296318c1bcce2d422aaca15ad08515e344feeda07df67a587"
checksum = "36c1fec1a2bb5866f07c25f68c26e565c4c200aebb96d7e55710c19d3e8ac49b"
dependencies = [
"memchr",
]
@ -5703,7 +5707,7 @@ checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.77",
"syn 2.0.79",
]
[[package]]


@ -20,7 +20,7 @@ default-members = [
resolver = "2"
[workspace.package]
version = "2.3.1-dev0"
version = "2.3.2-dev0"
edition = "2021"
authors = ["Olivier Dehaene"]
homepage = "https://github.com/huggingface/text-generation-inference"


@ -1,5 +1,5 @@
# Rust builder
FROM lukemathwalker/cargo-chef:latest-rust-1.80 AS chef
FROM lukemathwalker/cargo-chef:latest-rust-1.80.1 AS chef
WORKDIR /usr/src
ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse
@ -32,6 +32,7 @@ RUN cargo chef cook --profile release-opt --recipe-path recipe.json
ARG GIT_SHA
ARG DOCKER_LABEL
COPY Cargo.lock Cargo.lock
COPY Cargo.toml Cargo.toml
COPY rust-toolchain.toml rust-toolchain.toml
COPY proto proto
@ -39,7 +40,7 @@ COPY benchmark benchmark
COPY router router
COPY backends backends
COPY launcher launcher
RUN cargo build --profile release-opt
RUN cargo build --profile release-opt --frozen
# Python builder
# Adapted from: https://github.com/pytorch/pytorch/blob/master/Dockerfile

Dockerfile.nix (new file)

@ -0,0 +1,24 @@
# Build the image and get out the docker file:
#
# docker build -t tgi-nix-builder -f Dockerfile.nix
# docker run --log-driver=none tgi-nix-builder | docker load
FROM nixos/nix:2.18.8 AS builder
RUN echo "experimental-features = nix-command flakes" >> /etc/nix/nix.conf
RUN nix profile install nixpkgs#cachix
RUN cachix use text-generation-inference
WORKDIR /root
ADD . .
RUN nix build .
RUN mkdir /tmp/nix-store-closure
RUN cp -R $(nix-store -qR result/) /tmp/nix-store-closure
FROM ubuntu:24.04
WORKDIR /app
# Copy /nix/store
COPY --from=builder /tmp/nix-store-closure /nix/store
COPY --from=builder /root/result /app
RUN ldconfig
CMD ["ldconfig", "/app/bin/text-generation-launcher"]


@ -1,5 +1,5 @@
# Rust builder
FROM lukemathwalker/cargo-chef:latest-rust-1.80 AS chef
FROM lukemathwalker/cargo-chef:latest-rust-1.80.1 AS chef
WORKDIR /usr/src
ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse
@ -31,6 +31,7 @@ RUN cargo chef cook --profile release-opt --recipe-path recipe.json
ARG GIT_SHA
ARG DOCKER_LABEL
COPY Cargo.lock Cargo.lock
COPY Cargo.toml Cargo.toml
COPY rust-toolchain.toml rust-toolchain.toml
COPY proto proto
@ -38,10 +39,10 @@ COPY benchmark benchmark
COPY router router
COPY backends backends
COPY launcher launcher
RUN cargo build --profile release-opt
RUN cargo build --profile release-opt --frozen
# Text Generation Inference base image for RoCm
FROM rocm/dev-ubuntu-22.04:6.1.1_hip_update AS base
FROM rocm/dev-ubuntu-22.04:6.2 AS base
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
build-essential \
@ -50,33 +51,34 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins
curl \
git \
make \
libmsgpack-dev \
libssl-dev \
llvm-dev \
g++ \
# Needed to build VLLM & flash.
rocthrust-dev \
hipsparse-dev \
hipblas-dev \
hipblaslt-dev \
hipcub-dev \
rocblas-dev \
hiprand-dev \
hipfft-dev \
rocrand-dev \
miopen-hip-dev \
hipfft-dev \
hipcub-dev \
hipsolver-dev \
rccl-dev \
cmake \
python3.11-dev && \
python3.11-venv && \
rm -rf /var/lib/apt/lists/*
# Keep in sync with `server/pyproject.toml`
ARG MAMBA_VERSION=23.1.0-1
ARG PYTORCH_VERSION='2.3.0'
ARG ROCM_VERSION='6.0.2'
ARG PYTHON_VERSION='3.11.10'
# Automatically set by buildx
ARG TARGETPLATFORM
ENV PATH /opt/conda/bin:$PATH
ENV PATH=/opt/conda/bin:$PATH
ARG PYTORCH_ROCM_ARCH="gfx90a;gfx942"
# TGI seems to require libssl.so.1.1 instead of libssl.so.3 so we can't use ubuntu 22.04. Ubuntu 20.04 has python==3.8, and TGI requires python>=3.9, hence the need for miniconda.
# Install mamba
@ -100,41 +102,132 @@ RUN case ${TARGETPLATFORM} in \
/opt/conda/bin/conda install -y "python=${PYTHON_VERSION}" ;; \
esac && \
/opt/conda/bin/conda clean -ya
# Install flash-attention, torch dependencies
RUN pip install numpy einops ninja --no-cache-dir
RUN python3 -m pip install --upgrade pip && pip install numpy einops ninja joblib msgpack cmake --no-cache-dir && rm -rf /var/lib/apt/lists/*
RUN pip uninstall -y triton && \
git clone --depth 1 --single-branch https://github.com/ROCm/triton.git && \
cd triton/python && \
pip install .
RUN conda install mkl=2021
ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/rocm/lib/:/opt/conda/lib/python3.11/site-packages/torch/lib:/opt/conda/lib/
RUN git clone --depth 1 --recursive --single-branch --branch 2.3-patched https://github.com/fxmarty/pytorch.git pytorch && cd pytorch && pip install -r requirements.txt --no-cache-dir
ARG _GLIBCXX_USE_CXX11_ABI="1"
ARG CMAKE_PREFIX_PATH="/opt/conda"
ARG COMMON_WORKDIR=/
WORKDIR ${COMMON_WORKDIR}
# Install HIPBLASLt
FROM base AS build_hipblaslt
ARG HIPBLASLT_BRANCH="e6da924"
RUN git clone https://github.com/ROCm/hipBLASLt.git \
&& cd hipBLASLt \
&& git checkout ${HIPBLASLT_BRANCH} \
&& SCCACHE_IDLE_TIMEOUT=1800 ./install.sh --architecture ${PYTORCH_ROCM_ARCH} --legacy_hipblas_direct \
&& cd build/release \
&& make package
FROM scratch AS export_hipblaslt
ARG COMMON_WORKDIR
COPY --from=build_hipblaslt ${COMMON_WORKDIR}/hipBLASLt/build/release/*.deb /
# RCCL build stages
FROM base AS build_rccl
ARG RCCL_BRANCH="rocm-6.2.0"
RUN git clone https://github.com/ROCm/rccl \
&& cd rccl \
&& git checkout ${RCCL_BRANCH} \
&& ./install.sh -p --amdgpu_targets ${PYTORCH_ROCM_ARCH}
FROM scratch AS export_rccl
ARG COMMON_WORKDIR
COPY --from=build_rccl ${COMMON_WORKDIR}/rccl/build/release/*.deb /
# Triton build stages
FROM base AS build_triton
ARG TRITON_BRANCH="e192dba"
ARG TRITON_REPO="https://github.com/triton-lang/triton.git"
RUN python3 -m pip install ninja cmake wheel pybind11 && git clone ${TRITON_REPO} \
&& cd triton \
&& git checkout ${TRITON_BRANCH} \
&& cd python \
&& python3 setup.py bdist_wheel --dist-dir=dist
FROM scratch AS export_triton
ARG COMMON_WORKDIR
COPY --from=build_triton ${COMMON_WORKDIR}/triton/python/dist/*.whl /
# # AMD-SMI build stages
FROM base AS build_amdsmi
RUN cd /opt/rocm/share/amd_smi \
&& pip wheel . --wheel-dir=dist
FROM scratch AS export_amdsmi
COPY --from=build_amdsmi /opt/rocm/share/amd_smi/dist/*.whl /
FROM base as build_pytorch
RUN --mount=type=bind,from=export_hipblaslt,src=/,target=/install \
if ls /install/*.deb; then \
dpkg -i /install/*.deb \
&& sed -i 's/, hipblaslt-dev \(.*\), hipcub-dev/, hipcub-dev/g' /var/lib/dpkg/status \
&& sed -i 's/, hipblaslt \(.*\), hipfft/, hipfft/g' /var/lib/dpkg/status; \
fi
ARG BUILD_ENVIRONMENT=pytorch-linux-jammy-rocm6.2-py3.11
ARG PYTORCH_ROCM_ARCH="gfx90a;gfx942"
ARG BUILD_CAFFE2="0" \
BUILD_CAFFE2_OPS="0" \
USE_CUDA="0" \
USE_ROCM="1" \
BUILD_TEST="0" \
USE_FBGEMM="0" \
USE_NNPACK="0" \
USE_QNNPACK="0" \
USE_XNNPACK="0" \
USE_FLASH_ATTENTION="1" \
USE_MEM_EFF_ATTENTION="0"
RUN cd pytorch && python tools/amd_build/build_amd.py && python setup.py install
# A commit to fix the output scaling factor issue in _scaled_mm
# Not yet in 2.5.0-rc1
ARG PYTORCH_BRANCH="cedc116"
ARG PYTORCH_VISION_BRANCH="v0.19.1"
ARG PYTORCH_REPO="https://github.com/ROCm/pytorch.git"
# Set AS recommended: https://github.com/ROCm/triton/wiki/A-script-to-set-program-execution-environment-in-ROCm
ENV HIP_FORCE_DEV_KERNARG=1
RUN git clone ${PYTORCH_REPO} pytorch \
&& cd pytorch && git checkout ${PYTORCH_BRANCH} && git submodule update --init --recursive \
&& pip install -r requirements.txt --no-cache-dir \
&& python tools/amd_build/build_amd.py \
&& CMAKE_PREFIX_PATH=$(python3 -c 'import sys; print(sys.prefix)') python3 setup.py bdist_wheel --dist-dir=dist
FROM scratch as export_pytorch
ARG COMMON_WORKDIR
COPY --from=build_pytorch ${COMMON_WORKDIR}/pytorch/dist/*.whl /
# On MI250 and MI300, performances for flash with Triton FA are slightly better than CK.
# However, Triton requires a tuning for each prompt length, which is prohibitive.
ENV ROCM_USE_FLASH_ATTN_V2_TRITON=0
FROM base AS install_deps
FROM base AS kernel-builder
ARG COMMON_WORKDIR
# Install hipblaslt
RUN --mount=type=bind,from=export_hipblaslt,src=/,target=/install \
if ls /install/*.deb; then \
dpkg -i /install/*.deb \
&& sed -i 's/, hipblaslt-dev \(.*\), hipcub-dev/, hipcub-dev/g' /var/lib/dpkg/status \
&& sed -i 's/, hipblaslt \(.*\), hipfft/, hipfft/g' /var/lib/dpkg/status; \
fi
RUN --mount=type=bind,from=export_rccl,src=/,target=/install \
if ls /install/*.deb; then \
dpkg -i /install/*.deb \
# RCCL needs to be installed twice
&& dpkg -i /install/*.deb \
&& sed -i 's/, rccl-dev \(.*\), rocalution/, rocalution/g' /var/lib/dpkg/status \
&& sed -i 's/, rccl \(.*\), rocalution/, rocalution/g' /var/lib/dpkg/status; \
fi
RUN --mount=type=bind,from=export_triton,src=/,target=/install \
if ls /install/*.whl; then \
# Preemptively uninstall to prevent pip same-version no-installs
pip uninstall -y triton \
&& pip install /install/*.whl; \
fi
RUN --mount=type=bind,from=export_amdsmi,src=/,target=/install \
# Preemptively uninstall to prevent pip same-version no-installs
pip uninstall -y amdsmi \
&& pip install /install/*.whl;
RUN --mount=type=bind,from=export_pytorch,src=/,target=/install \
if ls /install/*.whl; then \
# Preemptively uninstall to prevent pip same-version no-installs
pip uninstall -y torch torchvision \
&& pip install /install/*.whl; \
fi
FROM install_deps AS kernel-builder
# # Build vllm kernels
FROM kernel-builder AS vllm-builder
@ -174,7 +267,7 @@ COPY server/exllamav2_kernels/ .
RUN python setup.py build
FROM base AS base-copy
FROM install_deps AS base-copy
# Text Generation Inference base env
ENV HF_HOME=/data \
@ -224,6 +317,19 @@ ENTRYPOINT ["./entrypoint.sh"]
# Final image
FROM base-copy
# Set AS recommended: https://github.com/ROCm/triton/wiki/A-script-to-set-program-execution-environment-in-ROCm
ENV HIP_FORCE_DEV_KERNARG=1
# On MI250 and MI300, performances for flash with Triton FA are slightly better than CK.
# However, Triton requires a tuning for each prompt length, which is prohibitive.
ENV ROCM_USE_FLASH_ATTN_V2_TRITON=0
ENV ROCM_USE_CUSTOM_PAGED_ATTN=1
ENV PYTORCH_TUNABLEOP_TUNING_AFTER_WARMUP=0
ENV VLLM_MOE_PADDING=0
ENV ATTENTION=paged
ENV USE_PREFIX_CACHING=0
ENV ROCM_USE_SKINNY_GEMM=1
COPY ./tgi-entrypoint.sh /tgi-entrypoint.sh
RUN chmod +x /tgi-entrypoint.sh


@ -1,6 +1,6 @@
ARG PLATFORM=xpu
FROM lukemathwalker/cargo-chef:latest-rust-1.80 AS chef
FROM lukemathwalker/cargo-chef:latest-rust-1.80.1 AS chef
WORKDIR /usr/src
ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse
@ -32,6 +32,7 @@ RUN cargo chef cook --profile release-opt --recipe-path recipe.json
ARG GIT_SHA
ARG DOCKER_LABEL
COPY Cargo.lock Cargo.lock
COPY Cargo.toml Cargo.toml
COPY rust-toolchain.toml rust-toolchain.toml
COPY proto proto
@ -39,7 +40,7 @@ COPY benchmark benchmark
COPY router router
COPY backends backends
COPY launcher launcher
RUN cargo build --profile release-opt
RUN cargo build --profile release-opt --frozen
# Text Generation Inference base image for Intel
@ -52,7 +53,7 @@ ARG MAMBA_VERSION=23.1.0-1
ARG PYTHON_VERSION='3.11.10'
# Automatically set by buildx
ARG TARGETPLATFORM
ENV PATH /opt/conda/bin:$PATH
ENV PATH=/opt/conda/bin:$PATH
# TGI seems to require libssl.so.1.1 instead of libssl.so.3 so we can't use ubuntu 22.04. Ubuntu 20.04 has python==3.8, and TGI requires python>=3.9, hence the need for miniconda.
# Install mamba
@ -111,6 +112,8 @@ ENV PATH=/opt/conda/bin:/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/bin:/opt/
ENV CCL_ZE_IPC_EXCHANGE=sockets
ENV CMAKE_PREFIX_PATH=/opt/intel/oneapi/mkl/latest/lib/cmake:/opt/intel/oneapi/compiler/latest
ENV CPATH=/opt/intel/oneapi/mpi/latest/include:/opt/intel/oneapi/ccl/latest/include:/opt/intel/oneapi/mkl/latest/include
ENV TORCH_LLM_ALLREDUCE=1
ENV CCL_TOPO_FABRIC_VERTEX_CONNECTION_CHECK=0
# Install benchmarker
COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
@ -127,12 +130,22 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins
curl \
ca-certificates \
make \
g++ \
g++-12 \
gcc-12 \
git \
wget \
cmake \
libnuma-dev
RUN update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-12 12
RUN update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-12 12
RUN update-alternatives --install /usr/bin/cc cc /usr/bin/gcc 30
RUN update-alternatives --set cc /usr/bin/gcc
RUN update-alternatives --install /usr/bin/c++ c++ /usr/bin/g++ 30
RUN update-alternatives --set c++ /usr/bin/g++
ENV HUGGINGFACE_HUB_CACHE=/data \
HF_HUB_ENABLE_HF_TRANSFER=1 \
PORT=80
@ -164,16 +177,17 @@ RUN case ${TARGETPLATFORM} in \
RUN conda install -c conda-forge gperftools mkl
RUN pip install https://download.pytorch.org/whl/nightly/cpu/torch-2.4.0.dev20240612%2Bcpu-cp311-cp311-linux_x86_64.whl
RUN pip install https://download.pytorch.org/whl/nightly/cpu/torchvision-0.19.0.dev20240612%2Bcpu-cp311-cp311-linux_x86_64.whl
RUN pip install https://download.pytorch.org/whl/nightly/cpu/torchaudio-2.4.0.dev20240612%2Bcpu-cp311-cp311-linux_x86_64.whl
RUN pip install https://download.pytorch.org/whl/nightly/cpu/torch-2.5.0.dev20240815%2Bcpu-cp311-cp311-linux_x86_64.whl
RUN pip install https://download.pytorch.org/whl/nightly/cpu/torchvision-0.20.0.dev20240815%2Bcpu-cp311-cp311-linux_x86_64.whl
RUN pip install https://download.pytorch.org/whl/nightly/cpu/torchaudio-2.4.0.dev20240815%2Bcpu-cp311-cp311-linux_x86_64.whl
RUN pip install triton py-libnuma
WORKDIR /usr/src
RUN git clone https://github.com/intel/intel-extension-for-pytorch && cd intel-extension-for-pytorch && git checkout eda7a7c42df6f9a64e0de9c2b69304ee02f2c32a
RUN git clone https://github.com/intel/torch-ccl.git && cd torch-ccl && git checkout ccl_torch_dev_0131
RUN git clone https://github.com/intel/intel-extension-for-pytorch && cd intel-extension-for-pytorch && git checkout f86e93e4890dc2c989024d148d415c9aa8a1649f
RUN git clone https://github.com/intel/torch-ccl.git && cd torch-ccl && git checkout v2.4.0+cpu+rc0
RUN cd intel-extension-for-pytorch && git submodule sync && git submodule update --init --recursive && python setup.py install


@ -83,7 +83,7 @@ model=HuggingFaceH4/zephyr-7b-beta
volume=$PWD/data
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data \
ghcr.io/huggingface/text-generation-inference:2.3.0 --model-id $model
ghcr.io/huggingface/text-generation-inference:2.3.1 --model-id $model
```
And then you can make requests like
@ -120,7 +120,7 @@ curl localhost:3000/v1/chat/completions \
**Note:** To use NVIDIA GPUs, you need to install the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html). We also recommend using NVIDIA drivers with CUDA version 12.2 or higher. For running the Docker container on a machine with no GPUs or CUDA support, it is enough to remove the `--gpus all` flag and add `--disable-custom-kernels`, please note CPU is not the intended platform for this project, so performance might be subpar.
**Note:** TGI supports AMD Instinct MI210 and MI250 GPUs. Details can be found in the [Supported Hardware documentation](https://huggingface.co/docs/text-generation-inference/supported_models#supported-hardware). To use AMD GPUs, please use `docker run --device /dev/kfd --device /dev/dri --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:2.2.0-rocm --model-id $model` instead of the command above.
**Note:** TGI supports AMD Instinct MI210 and MI250 GPUs. Details can be found in the [Supported Hardware documentation](https://huggingface.co/docs/text-generation-inference/supported_models#supported-hardware). To use AMD GPUs, please use `docker run --device /dev/kfd --device /dev/dri --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:2.3.1-rocm --model-id $model` instead of the command above.
To see all options to serve your models (in the [code](https://github.com/huggingface/text-generation-inference/blob/main/launcher/src/main.rs) or in the cli):
```
@ -150,7 +150,7 @@ model=meta-llama/Meta-Llama-3.1-8B-Instruct
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
token=<your cli READ token>
docker run --gpus all --shm-size 1g -e HF_TOKEN=$token -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:2.0 --model-id $model
docker run --gpus all --shm-size 1g -e HF_TOKEN=$token -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:2.3.1 --model-id $model
```
### A note on Shared Memory (shm)


@ -100,6 +100,7 @@ pub async fn connect_backend(
.map_err(V3Error::Warmup)?,
)?;
tracing::info!("Setting max batch total tokens to {max_batch_total_tokens}");
metrics::gauge!("tgi_batch_max_total_tokens").set(max_batch_total_tokens);
let backend_info = BackendInfo {
waiting_served_ratio,


@ -27,3 +27,6 @@ asyncio_mode = "auto"
[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"
[tool.isort]
profile = "black"


@ -28,11 +28,17 @@ class ToolCall(BaseModel):
function: dict
class Chunk(BaseModel):
type: str
text: Optional[str] = None
image_url: Any = None
class Message(BaseModel):
# Role of the message sender
role: str
# Content of the message
content: Optional[str] = None
content: Optional[Union[str, List[Chunk]]] = None
# Optional name of the message sender
name: Optional[str] = None
# Tool calls associated with the chat completion
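For orientation, here is a minimal, self-contained sketch of how the widened `content` field behaves, using the `Chunk` and `Message` shapes from this hunk (the surrounding module also defines `ToolCall` and further fields omitted here):

```python
from typing import Any, List, Optional, Union

from pydantic import BaseModel


class Chunk(BaseModel):
    # One piece of multimodal content: plain text or an image URL.
    type: str
    text: Optional[str] = None
    image_url: Any = None


class Message(BaseModel):
    # Role of the message sender ("user", "assistant", ...).
    role: str
    # Content may now be a plain string or a list of chunks.
    content: Optional[Union[str, List[Chunk]]] = None
    # Optional name of the message sender.
    name: Optional[str] = None


# Both forms validate against the widened schema:
text_only = Message(role="user", content="What is deep learning?")
multimodal = Message(
    role="user",
    content=[
        Chunk(type="text", text="Describe this image."),
        Chunk(type="image_url", image_url={"url": "https://example.com/cat.png"}),
    ],
)
```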


@ -10,7 +10,7 @@
"name": "Apache 2.0",
"url": "https://www.apache.org/licenses/LICENSE-2.0"
},
"version": "2.3.1-dev0"
"version": "2.3.2-dev0"
},
"paths": {
"/": {
@ -2114,12 +2114,18 @@
"ToolType": {
"oneOf": [
{
"type": "object",
"default": null,
"nullable": true
"type": "string",
"description": "Means the model can pick between generating a message or calling one or more tools.",
"enum": [
"auto"
]
},
{
"type": "string"
"type": "string",
"description": "Means the model will not call any tool and instead generates a message.",
"enum": [
"none"
]
},
{
"type": "object",
@ -2131,13 +2137,10 @@
"$ref": "#/components/schemas/FunctionName"
}
}
},
{
"type": "object",
"default": null,
"nullable": true
}
]
],
"description": "Controls which (if any) tool is called by the model.",
"example": "auto"
},
"Url": {
"type": "object",

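For illustration, the three `tool_choice` shapes the revised `ToolType` schema accepts, sketched as request fragments (the function name is a placeholder; the object form nests a `FunctionName` per the schema):

```python
import json

# "auto": the model may generate a message or call one or more tools.
auto_choice = {"tool_choice": "auto"}

# "none": the model will not call any tool and instead generates a message.
none_choice = {"tool_choice": "none"}

# Object form: force a call to one specific named function.
named_choice = {"tool_choice": {"function": {"name": "get_current_weather"}}}

for fragment in (auto_choice, none_choice, named_choice):
    print(json.dumps(fragment))
```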

@ -3,6 +3,8 @@
title: Text Generation Inference
- local: quicktour
title: Quick Tour
- local: supported_models
title: Supported Models
- local: installation_nvidia
title: Using TGI with Nvidia GPUs
- local: installation_amd
@ -15,8 +17,7 @@
title: Using TGI with Intel GPUs
- local: installation
title: Installation from source
- local: supported_models
title: Supported Models and Hardware
- local: architecture
title: Internal Architecture
- local: usage_statistics


@ -10,7 +10,7 @@ This diagram shows well there are these separate components:
- **The router**, also named `webserver`, that receives the client requests, buffers them, creates some batches, and prepares gRPC calls to a model server.
- **The model server**, responsible for receiving the gRPC requests and processing the inference on the model. If the model is sharded across multiple accelerators (e.g.: multiple GPUs), the model server shards might be synchronized via NCCL or equivalent.
- **The launcher** is a helper thar will be able to launch one or several model servers (if model is sharded), and it launches the router with the compatible arguments.
- **The launcher** is a helper that will be able to launch one or several model servers (if model is sharded), and it launches the router with the compatible arguments.
The router and the model server can be two different machines, they do not need to be deployed together.


@ -19,6 +19,6 @@ docker run --gpus all \
--shm-size 1g \
-e HF_TOKEN=$token \
-p 8080:80 \
-v $volume:/data ghcr.io/huggingface/text-generation-inference:2.0.4 \
-v $volume:/data ghcr.io/huggingface/text-generation-inference:2.3.1 \
--model-id $model
```


@ -36,6 +36,12 @@ To use LoRA in TGI, when starting the server, you can specify the list of LoRA m
LORA_ADAPTERS=predibase/customer_support,predibase/dbpedia
```
To specify model revision, use `adapter_id@revision`, as follows:
```bash
LORA_ADAPTERS=predibase/customer_support@main,predibase/dbpedia@rev2
```
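Once the server is running with adapters loaded, a request can select one by id. A minimal sketch using `requests` (host, prompt, and adapter name are illustrative):

```python
import requests

# Ask TGI to generate with a specific LoRA adapter applied.
response = requests.post(
    "http://127.0.0.1:8080/generate",
    json={
        "inputs": "What is deep learning?",
        "parameters": {"adapter_id": "predibase/customer_support"},
    },
)
print(response.json()["generated_text"])
```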
To use a locally stored lora adapter, use `adapter-name=/path/to/adapter`, as seen below. When you want to use this adapter, set `"parameters": {"adapter_id": "adapter-name"}"`
```bash


@ -19,7 +19,7 @@ bitsandbytes is a library used to apply 8-bit and 4-bit quantization to models.
In TGI, you can use 8-bit quantization by adding `--quantize bitsandbytes` like below 👇
```bash
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:latest --model-id $model --quantize bitsandbytes
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:2.3.1 --model-id $model --quantize bitsandbytes
```
4-bit quantization is also possible with bitsandbytes. You can choose one of the following 4-bit data types: 4-bit float (`fp4`), or 4-bit `NormalFloat` (`nf4`). These data types were introduced in the context of parameter-efficient fine-tuning, but you can apply them for inference by automatically converting the model weights on load.
@ -27,7 +27,7 @@ docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingf
In TGI, you can use 4-bit quantization by adding `--quantize bitsandbytes-nf4` or `--quantize bitsandbytes-fp4` like below 👇
```bash
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:latest --model-id $model --quantize bitsandbytes-nf4
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:2.3.1 --model-id $model --quantize bitsandbytes-nf4
```
You can get more information about 8-bit quantization by reading this [blog post](https://huggingface.co/blog/hf-bitsandbytes-integration), and 4-bit quantization by reading [this blog post](https://huggingface.co/blog/4bit-transformers-bitsandbytes).
@ -48,7 +48,7 @@ $$({\hat{W}_{l}}^{*} = argmin_{\hat{W_{l}}} ||W_{l}X-\hat{W}_{l}X||^{2}_{2})$$
TGI allows you to both run an already GPTQ quantized model (see available models [here](https://huggingface.co/models?search=gptq)) or quantize a model of your choice using quantization script. You can run a quantized model by simply passing --quantize like below 👇
```bash
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:latest --model-id $model --quantize gptq
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:2.3.1 --model-id $model --quantize gptq
```
Note that TGI's GPTQ implementation doesn't use [AutoGPTQ](https://github.com/PanQiWei/AutoGPTQ) under the hood. However, models quantized using AutoGPTQ or Optimum can still be served by TGI.


@ -11,7 +11,7 @@ volume=$PWD/data # share a volume with the Docker container to avoid downloading
docker run --rm -it --cap-add=SYS_PTRACE --security-opt seccomp=unconfined \
--device=/dev/kfd --device=/dev/dri --group-add video \
--ipc=host --shm-size 256g --net host -v $volume:/data \
ghcr.io/huggingface/text-generation-inference:2.3.0-rocm \
ghcr.io/huggingface/text-generation-inference:2.3.1-rocm \
--model-id $model
```
@ -31,6 +31,12 @@ Two implementations of Flash Attention are available for ROCm, the first is [ROC
By default, the Composable Kernel implementation is used. However, the Triton implementation has slightly lower latency on MI250 and MI300, but requires a warmup which can be prohibitive as it needs to be done again for each new prompt length. If needed, FA Triton implementation can be enabled with `--env ROCM_USE_FLASH_ATTN_V2_TRITON="0"` when launching TGI's docker container.
## Custom PagedAttention
For better performance on ROCm, a custom Paged Attention kernel is available and is enabled by default. To disable it and fall back to the PagedAttention v2 kernel, set the environment variable `ROCM_USE_CUSTOM_PAGED_ATTN=0`.
The custom kernel supports bf16 and fp16 data types, block size of 16, head size of 128, a maximum context length of 16k, and GQA ratios between 1 and 16. For other configurations, we use the PagedAttention v2 kernel.
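As a hypothetical restatement of the constraints above (not the actual kernel-dispatch code), the fallback decision can be pictured as:

```python
def can_use_custom_paged_attention(
    dtype: str,
    block_size: int,
    head_size: int,
    max_context_len: int,
    num_query_heads: int,
    num_kv_heads: int,
) -> bool:
    """Mirror the documented ROCm custom-kernel constraints ("16k" read as 16384)."""
    gqa_ratio = num_query_heads // num_kv_heads
    return (
        dtype in ("bf16", "fp16")
        and block_size == 16
        and head_size == 128
        and max_context_len <= 16384
        and 1 <= gqa_ratio <= 16
    )


# Unsupported configurations fall back to the PagedAttention v2 kernel:
assert can_use_custom_paged_attention("fp16", 16, 128, 4096, 32, 8)
assert not can_use_custom_paged_attention("fp8", 16, 128, 4096, 32, 8)
```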
## Unsupported features
The following features are currently not supported in the ROCm version of TGI, and the support may be extended in the future:


@ -12,7 +12,7 @@ volume=$PWD/data # share a volume with the Docker container to avoid downloading
docker run --rm --privileged --cap-add=sys_nice \
--device=/dev/dri \
--ipc=host --shm-size 1g --net host -v $volume:/data \
ghcr.io/huggingface/text-generation-inference:2.3.0-intel-xpu \
ghcr.io/huggingface/text-generation-inference:2.3.1-intel-xpu \
--model-id $model --cuda-graphs 0
```
@ -29,7 +29,7 @@ volume=$PWD/data # share a volume with the Docker container to avoid downloading
docker run --rm --privileged --cap-add=sys_nice \
--device=/dev/dri \
--ipc=host --shm-size 1g --net host -v $volume:/data \
ghcr.io/huggingface/text-generation-inference:2.3.0-intel-cpu \
ghcr.io/huggingface/text-generation-inference:2.3.1-intel-cpu \
--model-id $model --cuda-graphs 0
```


@ -11,7 +11,7 @@ model=teknium/OpenHermes-2.5-Mistral-7B
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
docker run --gpus all --shm-size 64g -p 8080:80 -v $volume:/data \
ghcr.io/huggingface/text-generation-inference:2.3.0 \
ghcr.io/huggingface/text-generation-inference:2.3.1 \
--model-id $model
```


@ -11,14 +11,13 @@ model=teknium/OpenHermes-2.5-Mistral-7B
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data \
ghcr.io/huggingface/text-generation-inference:2.3.0 \
ghcr.io/huggingface/text-generation-inference:2.3.1 \
--model-id $model
```
<Tip>
If you want to serve gated or private models, which provide
controlled access to sensitive or proprietary content, refer to
If you want to serve gated or private models, please refer to
[this guide](https://huggingface.co/docs/text-generation-inference/en/basic_tutorials/gated_model_access)
for detailed instructions.
@ -97,7 +96,7 @@ curl 127.0.0.1:8080/generate \
To see all possible deploy flags and options, you can use the `--help` flag. It's possible to configure the number of shards, quantization, generation parameters, and more.
```bash
docker run ghcr.io/huggingface/text-generation-inference:2.2.0 --help
docker run ghcr.io/huggingface/text-generation-inference:2.3.1 --help
```
</Tip>


@ -89,6 +89,15 @@ Options:
[env: DTYPE=]
[possible values: float16, bfloat16]
```
## KV_CACHE_DTYPE
```shell
--kv-cache-dtype <KV_CACHE_DTYPE>
Specify the dtype for the key-value cache. When this option is not provided, the dtype of the model is used (typically `float16` or `bfloat16`). Currently the only supported value is `fp8_e5m2` on CUDA
[env: KV_CACHE_DTYPE=]
[possible values: fp8_e5m2]
```
## TRUST_REMOTE_CODE
```shell


@ -1,9 +1,7 @@
# Supported Models and Hardware
# Supported Models
Text Generation Inference enables serving optimized models on specific hardware for the highest performance. The following sections list which models (VLMs & LLMs) are supported.
## Supported Models
Text Generation Inference enables serving optimized models. The following sections list which models (VLMs & LLMs) are supported.
- [Deepseek V2](https://huggingface.co/deepseek-ai/DeepSeek-V2)
- [Idefics 2](https://huggingface.co/HuggingFaceM4/idefics2-8b) (Multimodal)
@ -20,6 +18,7 @@ Text Generation Inference enables serving optimized models on specific hardware
- [Mixtral](https://huggingface.co/mistralai/Mixtral-8x22B-Instruct-v0.1)
- [Gpt Bigcode](https://huggingface.co/bigcode/gpt_bigcode-santacoder)
- [Phi](https://huggingface.co/microsoft/phi-1_5)
- [PhiMoe](https://huggingface.co/microsoft/Phi-3.5-MoE-instruct)
- [Baichuan](https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat)
- [Falcon](https://huggingface.co/tiiuae/falcon-7b-instruct)
- [StarCoder 2](https://huggingface.co/bigcode/starcoder2-15b-instruct-v0.1)
@ -34,6 +33,8 @@ Text Generation Inference enables serving optimized models on specific hardware
- [Gpt Neox](https://huggingface.co/EleutherAI/gpt-neox-20b)
- [Gptj](https://huggingface.co/EleutherAI/gpt-j-6b)
- [Idefics](https://huggingface.co/HuggingFaceM4/idefics-9b) (Multimodal)
- [Mllama](https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct) (Multimodal)
If the above list lacks the model you would like to serve, depending on the model's pipeline type, you can try to initialize and serve the model anyways to see how well it performs, but performance isn't guaranteed for non-optimized models:


@ -497,11 +497,11 @@
"systems": "systems_7"
},
"locked": {
"lastModified": 1710146030,
"narHash": "sha256-SZ5L6eA7HJ/nmkzGG7/ISclqe6oZdOZTNoesiInkXPQ=",
"lastModified": 1726560853,
"narHash": "sha256-X6rJYSESBVr3hBoH0WbKE5KvhPU5bloyZ2L4K60/fPQ=",
"owner": "numtide",
"repo": "flake-utils",
"rev": "b1d9ab70662946ef0850d488da1c9019f3a9752a",
"rev": "c1dfcf08411b08f6b8615f7d8971a2bfa81d5e8a",
"type": "github"
},
"original": {
@ -718,11 +718,11 @@
},
"nixpkgs_6": {
"locked": {
"lastModified": 1724915739,
"narHash": "sha256-7PgRge4mn5akFvhPwefuaLQGbF5BnmxlwZJEf7CgbrE=",
"lastModified": 1727675176,
"narHash": "sha256-xIjBFMYldWvj+g8ahxMPofsj+OqxvKJN6YylNHQ7gn4=",
"owner": "nixos",
"repo": "nixpkgs",
"rev": "85be051bb60943d3328d91aaf2598798f87e19af",
"rev": "a6d0207fea9212d28cd3d487efe6bc699663b93a",
"type": "github"
},
"original": {
@ -853,11 +853,11 @@
]
},
"locked": {
"lastModified": 1726626348,
"narHash": "sha256-sYV7e1B1yLcxo8/h+/hTwzZYmaju2oObNiy5iRI0C30=",
"lastModified": 1727836133,
"narHash": "sha256-JE0zciM5IGWvK8J/pE2VldNBf7oyMH5WrU8tZArefbg=",
"owner": "oxalica",
"repo": "rust-overlay",
"rev": "6fd52ad8bd88f39efb2c999cc971921c2fb9f3a2",
"rev": "02321540b0c8000b36889b1b974d1fec585b25a4",
"type": "github"
},
"original": {
@ -978,11 +978,11 @@
"nixpkgs": "nixpkgs_6"
},
"locked": {
"lastModified": 1727353315,
"narHash": "sha256-yZovq/6P8Z199r7e+NbTXyCqRgK6grRkLxYHWHnHckI=",
"lastModified": 1728381423,
"narHash": "sha256-gpHy1WtlA8ZTd8XmxsdCoDd4Z7DE7co37lH7P+nsADA=",
"owner": "huggingface",
"repo": "text-generation-inference-nix",
"rev": "1d42c4125ebafb87707118168995675cc5050b9d",
"rev": "93123736c97e9f7bfe825bfaf3d7de0fc9a21a1e",
"type": "github"
},
"original": {


@ -37,6 +37,7 @@
overlays = [
rust-overlay.overlays.default
tgi-nix.overlays.default
(import nix/overlay.nix)
];
};
crateOverrides = import ./nix/crate-overrides.nix { inherit pkgs nix-filter; };
@ -141,7 +142,8 @@
};
};
packages.default = pkgs.writeShellApplication {
packages = rec {
default = pkgs.writeShellApplication {
name = "text-generation-inference";
runtimeInputs = [
server
@ -151,6 +153,16 @@
${launcher}/bin/text-generation-launcher "$@"
'';
};
dockerImage = pkgs.callPackage nix/docker.nix {
text-generation-inference = default;
};
dockerImageStreamed = pkgs.callPackage nix/docker.nix {
text-generation-inference = default;
stream = true;
};
};
}
);
}


@ -336,6 +336,7 @@ def launcher(event_loop):
use_flash_attention: bool = True,
disable_grammar_support: bool = False,
dtype: Optional[str] = None,
kv_cache_dtype: Optional[str] = None,
revision: Optional[str] = None,
max_input_length: Optional[int] = None,
max_batch_prefill_tokens: Optional[int] = None,
@ -375,6 +376,9 @@ def launcher(event_loop):
if dtype is not None:
args.append("--dtype")
args.append(dtype)
if kv_cache_dtype is not None:
args.append("--kv-cache-dtype")
args.append(kv_cache_dtype)
if revision is not None:
args.append("--revision")
args.append(revision)
@ -434,6 +438,7 @@ def launcher(event_loop):
use_flash_attention: bool = True,
disable_grammar_support: bool = False,
dtype: Optional[str] = None,
kv_cache_dtype: Optional[str] = None,
revision: Optional[str] = None,
max_input_length: Optional[int] = None,
max_batch_prefill_tokens: Optional[int] = None,
@ -456,6 +461,9 @@ def launcher(event_loop):
if dtype is not None:
args.append("--dtype")
args.append(dtype)
if kv_cache_dtype is not None:
args.append("--kv-cache-dtype")
args.append(kv_cache_dtype)
if revision is not None:
args.append("--revision")
args.append(revision)
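With the flag plumbed through, an integration test can opt into the FP8 KV cache. A hypothetical fixture in the style of the existing ones (model id and shard count are illustrative):

```python
import pytest


@pytest.fixture(scope="module")
def flash_llama_fp8_kv_handle(launcher):
    # kv_cache_dtype is forwarded to the launcher as --kv-cache-dtype.
    with launcher(
        "meta-llama/Meta-Llama-3.1-8B-Instruct",
        num_shard=1,
        kv_cache_dtype="fp8_e5m2",
    ) as handle:
        yield handle
```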
@ -484,6 +492,7 @@ def launcher(event_loop):
try:
container = client.containers.get(container_name)
container.stop()
container.remove()
container.wait()
except NotFound:
pass
@ -506,13 +515,28 @@ def launcher(event_loop):
volumes = [f"{DOCKER_VOLUME}:/data"]
if DOCKER_DEVICES:
devices = DOCKER_DEVICES.split(",")
if DOCKER_DEVICES.lower() == "none":
devices = []
else:
devices = DOCKER_DEVICES.strip().split(",")
visible = os.getenv("ROCR_VISIBLE_DEVICES")
if visible:
env["ROCR_VISIBLE_DEVICES"] = visible
device_requests = []
if not devices:
devices = None
elif devices == ["nvidia.com/gpu=all"]:
devices = None
device_requests = [
docker.types.DeviceRequest(
driver="cdi",
# count=gpu_count,
device_ids=[f"nvidia.com/gpu={i}"],
)
for i in range(gpu_count)
]
else:
devices = []
devices = None
device_requests = [
docker.types.DeviceRequest(count=gpu_count, capabilities=[["gpu"]])
]
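A condensed, hypothetical restatement of the device-selection branches above, to make the cases easier to scan (simplified: the CDI branch for `nvidia.com/gpu=all` and the `ROCR_VISIBLE_DEVICES` propagation are omitted):

```python
from typing import List, Optional, Tuple

import docker


def resolve_devices(
    docker_devices: Optional[str], gpu_count: int
) -> Tuple[Optional[List[str]], list]:
    """Return (devices, device_requests) for the docker run call."""
    if docker_devices and docker_devices.lower() == "none":
        # Explicit "none": CPU-only runner, no devices and no GPU requests.
        return None, []
    if docker_devices:
        # Plain device paths, e.g. "/dev/kfd,/dev/dri" for ROCm.
        return docker_devices.strip().split(","), []
    # Unset: request NVIDIA GPUs through the Docker SDK instead.
    return None, [
        docker.types.DeviceRequest(count=gpu_count, capabilities=[["gpu"]])
    ]
```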
@ -532,6 +556,7 @@ def launcher(event_loop):
shm_size="1G",
)
try:
yield ContainerLauncherHandle(client, container.name, port)
if not use_flash_attention:
@ -546,7 +571,11 @@ def launcher(event_loop):
container_output = container.logs().decode("utf-8")
print(container_output, file=sys.stderr)
finally:
try:
container.remove()
except Exception:
pass
if DOCKER_IMAGE is not None:
return docker_launcher
@ -589,7 +618,6 @@ def generate_multi():
max_new_tokens: int,
seed: Optional[int] = None,
) -> List[Response]:
import numpy as np
arange = np.arange(len(prompts))


@ -0,0 +1,104 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 128000,
"logprob": null,
"text": "<|begin_of_text|>"
},
{
"id": 3923,
"logprob": -5.6328125,
"text": "What"
},
{
"id": 374,
"logprob": -1.2265625,
"text": " is"
},
{
"id": 5655,
"logprob": -9.1015625,
"text": " deep"
},
{
"id": 6975,
"logprob": -1.8085938,
"text": " learning"
},
{
"id": 30,
"logprob": -1.0439453,
"text": "?"
}
],
"seed": null,
"tokens": [
{
"id": 18682,
"logprob": -2.1992188,
"special": false,
"text": " Deep"
},
{
"id": 6975,
"logprob": -0.079956055,
"special": false,
"text": " learning"
},
{
"id": 374,
"logprob": -0.2763672,
"special": false,
"text": " is"
},
{
"id": 264,
"logprob": -0.37548828,
"special": false,
"text": " a"
},
{
"id": 27084,
"logprob": -1.4628906,
"special": false,
"text": " subset"
},
{
"id": 315,
"logprob": -0.02885437,
"special": false,
"text": " of"
},
{
"id": 5780,
"logprob": -0.2565918,
"special": false,
"text": " machine"
},
{
"id": 6975,
"logprob": -0.0063438416,
"special": false,
"text": " learning"
},
{
"id": 430,
"logprob": -1.3056641,
"special": false,
"text": " that"
},
{
"id": 374,
"logprob": -1.6035156,
"special": false,
"text": " is"
}
],
"top_tokens": null
},
"generated_text": " Deep learning is a subset of machine learning that is"
}


@ -0,0 +1,57 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "eos_token",
"generated_tokens": 3,
"prefill": [
{
"id": 128000,
"logprob": null,
"text": "<|begin_of_text|>"
},
{
"id": 374,
"logprob": -22.96875,
"text": " is"
},
{
"id": 5655,
"logprob": -10.71875,
"text": " deep"
},
{
"id": 6975,
"logprob": -2.6992188,
"text": " learning"
},
{
"id": 30,
"logprob": -4.8398438,
"text": "?"
}
],
"seed": 0,
"tokens": [
{
"id": 720,
"logprob": -0.4411621,
"special": false,
"text": " \n"
},
{
"id": 220,
"logprob": -0.35864258,
"special": false,
"text": " "
},
{
"id": 128001,
"logprob": 0.0,
"special": true,
"text": "<|end_of_text|>"
}
],
"top_tokens": null
},
"generated_text": "What is deep learning? \n "
}

View File

@ -0,0 +1,418 @@
[
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 128000,
"logprob": null,
"text": "<|begin_of_text|>"
},
{
"id": 3923,
"logprob": -5.6328125,
"text": "What"
},
{
"id": 374,
"logprob": -1.2265625,
"text": " is"
},
{
"id": 5655,
"logprob": -9.1015625,
"text": " deep"
},
{
"id": 6975,
"logprob": -1.8085938,
"text": " learning"
},
{
"id": 30,
"logprob": -1.0439453,
"text": "?"
}
],
"seed": null,
"tokens": [
{
"id": 18682,
"logprob": -2.1992188,
"special": false,
"text": " Deep"
},
{
"id": 6975,
"logprob": -0.07897949,
"special": false,
"text": " learning"
},
{
"id": 374,
"logprob": -0.27734375,
"special": false,
"text": " is"
},
{
"id": 264,
"logprob": -0.37402344,
"special": false,
"text": " a"
},
{
"id": 27084,
"logprob": -1.4511719,
"special": false,
"text": " subset"
},
{
"id": 315,
"logprob": -0.02909851,
"special": false,
"text": " of"
},
{
"id": 5780,
"logprob": -0.25854492,
"special": false,
"text": " machine"
},
{
"id": 6975,
"logprob": -0.0061798096,
"special": false,
"text": " learning"
},
{
"id": 430,
"logprob": -1.3046875,
"special": false,
"text": " that"
},
{
"id": 374,
"logprob": -1.5537109,
"special": false,
"text": " is"
}
],
"top_tokens": null
},
"generated_text": " Deep learning is a subset of machine learning that is"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 128000,
"logprob": null,
"text": "<|begin_of_text|>"
},
{
"id": 3923,
"logprob": -5.6328125,
"text": "What"
},
{
"id": 374,
"logprob": -1.2265625,
"text": " is"
},
{
"id": 5655,
"logprob": -9.1015625,
"text": " deep"
},
{
"id": 6975,
"logprob": -1.8085938,
"text": " learning"
},
{
"id": 30,
"logprob": -1.0439453,
"text": "?"
}
],
"seed": null,
"tokens": [
{
"id": 18682,
"logprob": -2.1992188,
"special": false,
"text": " Deep"
},
{
"id": 6975,
"logprob": -0.07897949,
"special": false,
"text": " learning"
},
{
"id": 374,
"logprob": -0.27734375,
"special": false,
"text": " is"
},
{
"id": 264,
"logprob": -0.37402344,
"special": false,
"text": " a"
},
{
"id": 27084,
"logprob": -1.4511719,
"special": false,
"text": " subset"
},
{
"id": 315,
"logprob": -0.02909851,
"special": false,
"text": " of"
},
{
"id": 5780,
"logprob": -0.25854492,
"special": false,
"text": " machine"
},
{
"id": 6975,
"logprob": -0.0061798096,
"special": false,
"text": " learning"
},
{
"id": 430,
"logprob": -1.3046875,
"special": false,
"text": " that"
},
{
"id": 374,
"logprob": -1.5537109,
"special": false,
"text": " is"
}
],
"top_tokens": null
},
"generated_text": " Deep learning is a subset of machine learning that is"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 128000,
"logprob": null,
"text": "<|begin_of_text|>"
},
{
"id": 3923,
"logprob": -5.6328125,
"text": "What"
},
{
"id": 374,
"logprob": -1.2265625,
"text": " is"
},
{
"id": 5655,
"logprob": -9.1015625,
"text": " deep"
},
{
"id": 6975,
"logprob": -1.8085938,
"text": " learning"
},
{
"id": 30,
"logprob": -1.0439453,
"text": "?"
}
],
"seed": null,
"tokens": [
{
"id": 18682,
"logprob": -2.1992188,
"special": false,
"text": " Deep"
},
{
"id": 6975,
"logprob": -0.07897949,
"special": false,
"text": " learning"
},
{
"id": 374,
"logprob": -0.27734375,
"special": false,
"text": " is"
},
{
"id": 264,
"logprob": -0.37402344,
"special": false,
"text": " a"
},
{
"id": 27084,
"logprob": -1.4511719,
"special": false,
"text": " subset"
},
{
"id": 315,
"logprob": -0.02909851,
"special": false,
"text": " of"
},
{
"id": 5780,
"logprob": -0.25854492,
"special": false,
"text": " machine"
},
{
"id": 6975,
"logprob": -0.0061798096,
"special": false,
"text": " learning"
},
{
"id": 430,
"logprob": -1.3046875,
"special": false,
"text": " that"
},
{
"id": 374,
"logprob": -1.5537109,
"special": false,
"text": " is"
}
],
"top_tokens": null
},
"generated_text": " Deep learning is a subset of machine learning that is"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 128000,
"logprob": null,
"text": "<|begin_of_text|>"
},
{
"id": 3923,
"logprob": -5.6328125,
"text": "What"
},
{
"id": 374,
"logprob": -1.2265625,
"text": " is"
},
{
"id": 5655,
"logprob": -9.1015625,
"text": " deep"
},
{
"id": 6975,
"logprob": -1.8085938,
"text": " learning"
},
{
"id": 30,
"logprob": -1.0439453,
"text": "?"
}
],
"seed": null,
"tokens": [
{
"id": 18682,
"logprob": -2.1992188,
"special": false,
"text": " Deep"
},
{
"id": 6975,
"logprob": -0.07897949,
"special": false,
"text": " learning"
},
{
"id": 374,
"logprob": -0.27734375,
"special": false,
"text": " is"
},
{
"id": 264,
"logprob": -0.37402344,
"special": false,
"text": " a"
},
{
"id": 27084,
"logprob": -1.4511719,
"special": false,
"text": " subset"
},
{
"id": 315,
"logprob": -0.02909851,
"special": false,
"text": " of"
},
{
"id": 5780,
"logprob": -0.25854492,
"special": false,
"text": " machine"
},
{
"id": 6975,
"logprob": -0.0061798096,
"special": false,
"text": " learning"
},
{
"id": 430,
"logprob": -1.3046875,
"special": false,
"text": " that"
},
{
"id": 374,
"logprob": -1.5537109,
"special": false,
"text": " is"
}
],
"top_tokens": null
},
"generated_text": " Deep learning is a subset of machine learning that is"
}
]

View File

@ -0,0 +1,104 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 1824,
"logprob": -12.296875,
"text": "What"
},
{
"id": 349,
"logprob": -0.97216797,
"text": "is"
},
{
"id": 3534,
"logprob": -10.1796875,
"text": "deep"
},
{
"id": 5168,
"logprob": -0.9658203,
"text": "learning"
},
{
"id": 28804,
"logprob": -0.44384766,
"text": "?"
}
],
"seed": null,
"tokens": [
{
"id": 13,
"logprob": -0.50878906,
"special": false,
"text": "\n"
},
{
"id": 13,
"logprob": -0.8876953,
"special": false,
"text": "\n"
},
{
"id": 23229,
"logprob": -0.15124512,
"special": false,
"text": "Deep"
},
{
"id": 5168,
"logprob": -0.030288696,
"special": false,
"text": " learning"
},
{
"id": 349,
"logprob": -0.16687012,
"special": false,
"text": " is"
},
{
"id": 264,
"logprob": -0.17858887,
"special": false,
"text": " a"
},
{
"id": 19804,
"logprob": -0.8046875,
"special": false,
"text": " subset"
},
{
"id": 302,
"logprob": -0.007205963,
"special": false,
"text": " of"
},
{
"id": 5599,
"logprob": -0.090026855,
"special": false,
"text": " machine"
},
{
"id": 5168,
"logprob": -0.0030670166,
"special": false,
"text": " learning"
}
],
"top_tokens": null
},
"generated_text": "\n\nDeep learning is a subset of machine learning"
}

View File

@ -0,0 +1,99 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 349,
"logprob": -13.921875,
"text": "is"
},
{
"id": 3534,
"logprob": -11.2265625,
"text": "deep"
},
{
"id": 5168,
"logprob": -2.3886719,
"text": "learning"
},
{
"id": 28804,
"logprob": -4.7109375,
"text": "?"
}
],
"seed": 0,
"tokens": [
{
"id": 13,
"logprob": 0.0,
"special": false,
"text": "\n"
},
{
"id": 23229,
"logprob": -0.5229492,
"special": false,
"text": "Deep"
},
{
"id": 17504,
"logprob": 0.0,
"special": false,
"text": " Learning"
},
{
"id": 349,
"logprob": -0.5151367,
"special": false,
"text": " is"
},
{
"id": 264,
"logprob": 0.0,
"special": false,
"text": " a"
},
{
"id": 19804,
"logprob": 0.0,
"special": false,
"text": " subset"
},
{
"id": 302,
"logprob": 0.0,
"special": false,
"text": " of"
},
{
"id": 13253,
"logprob": -1.3359375,
"special": false,
"text": " Machine"
},
{
"id": 17504,
"logprob": 0.0,
"special": false,
"text": " Learning"
},
{
"id": 28725,
"logprob": 0.0,
"special": false,
"text": ","
}
],
"top_tokens": null
},
"generated_text": "What is deep learning?\nDeep Learning is a subset of Machine Learning,"
}

View File

@ -0,0 +1,418 @@
[
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 1824,
"logprob": -12.296875,
"text": "What"
},
{
"id": 349,
"logprob": -0.97216797,
"text": "is"
},
{
"id": 3534,
"logprob": -10.1796875,
"text": "deep"
},
{
"id": 5168,
"logprob": -0.9658203,
"text": "learning"
},
{
"id": 28804,
"logprob": -0.44384766,
"text": "?"
}
],
"seed": null,
"tokens": [
{
"id": 13,
"logprob": -0.50878906,
"special": false,
"text": "\n"
},
{
"id": 13,
"logprob": -0.8876953,
"special": false,
"text": "\n"
},
{
"id": 23229,
"logprob": -0.15136719,
"special": false,
"text": "Deep"
},
{
"id": 5168,
"logprob": -0.030273438,
"special": false,
"text": " learning"
},
{
"id": 349,
"logprob": -0.1665039,
"special": false,
"text": " is"
},
{
"id": 264,
"logprob": -0.1776123,
"special": false,
"text": " a"
},
{
"id": 19804,
"logprob": -0.8076172,
"special": false,
"text": " subset"
},
{
"id": 302,
"logprob": -0.007183075,
"special": false,
"text": " of"
},
{
"id": 5599,
"logprob": -0.090148926,
"special": false,
"text": " machine"
},
{
"id": 5168,
"logprob": -0.0030670166,
"special": false,
"text": " learning"
}
],
"top_tokens": null
},
"generated_text": "\n\nDeep learning is a subset of machine learning"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 1824,
"logprob": -12.34375,
"text": "What"
},
{
"id": 349,
"logprob": -0.96728516,
"text": "is"
},
{
"id": 3534,
"logprob": -10.1796875,
"text": "deep"
},
{
"id": 5168,
"logprob": -0.97265625,
"text": "learning"
},
{
"id": 28804,
"logprob": -0.44189453,
"text": "?"
}
],
"seed": null,
"tokens": [
{
"id": 13,
"logprob": -0.51220703,
"special": false,
"text": "\n"
},
{
"id": 13,
"logprob": -0.87402344,
"special": false,
"text": "\n"
},
{
"id": 23229,
"logprob": -0.15039062,
"special": false,
"text": "Deep"
},
{
"id": 5168,
"logprob": -0.030288696,
"special": false,
"text": " learning"
},
{
"id": 349,
"logprob": -0.1652832,
"special": false,
"text": " is"
},
{
"id": 264,
"logprob": -0.17858887,
"special": false,
"text": " a"
},
{
"id": 19804,
"logprob": -0.81103516,
"special": false,
"text": " subset"
},
{
"id": 302,
"logprob": -0.007183075,
"special": false,
"text": " of"
},
{
"id": 5599,
"logprob": -0.08880615,
"special": false,
"text": " machine"
},
{
"id": 5168,
"logprob": -0.0030612946,
"special": false,
"text": " learning"
}
],
"top_tokens": null
},
"generated_text": "\n\nDeep learning is a subset of machine learning"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 1824,
"logprob": -12.34375,
"text": "What"
},
{
"id": 349,
"logprob": -0.96728516,
"text": "is"
},
{
"id": 3534,
"logprob": -10.1796875,
"text": "deep"
},
{
"id": 5168,
"logprob": -0.97265625,
"text": "learning"
},
{
"id": 28804,
"logprob": -0.44189453,
"text": "?"
}
],
"seed": null,
"tokens": [
{
"id": 13,
"logprob": -0.51220703,
"special": false,
"text": "\n"
},
{
"id": 13,
"logprob": -0.87402344,
"special": false,
"text": "\n"
},
{
"id": 23229,
"logprob": -0.15039062,
"special": false,
"text": "Deep"
},
{
"id": 5168,
"logprob": -0.030288696,
"special": false,
"text": " learning"
},
{
"id": 349,
"logprob": -0.1652832,
"special": false,
"text": " is"
},
{
"id": 264,
"logprob": -0.17858887,
"special": false,
"text": " a"
},
{
"id": 19804,
"logprob": -0.81103516,
"special": false,
"text": " subset"
},
{
"id": 302,
"logprob": -0.007183075,
"special": false,
"text": " of"
},
{
"id": 5599,
"logprob": -0.08880615,
"special": false,
"text": " machine"
},
{
"id": 5168,
"logprob": -0.0030612946,
"special": false,
"text": " learning"
}
],
"top_tokens": null
},
"generated_text": "\n\nDeep learning is a subset of machine learning"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 1824,
"logprob": -12.34375,
"text": "What"
},
{
"id": 349,
"logprob": -0.96728516,
"text": "is"
},
{
"id": 3534,
"logprob": -10.1796875,
"text": "deep"
},
{
"id": 5168,
"logprob": -0.97265625,
"text": "learning"
},
{
"id": 28804,
"logprob": -0.44189453,
"text": "?"
}
],
"seed": null,
"tokens": [
{
"id": 13,
"logprob": -0.51220703,
"special": false,
"text": "\n"
},
{
"id": 13,
"logprob": -0.87402344,
"special": false,
"text": "\n"
},
{
"id": 23229,
"logprob": -0.15039062,
"special": false,
"text": "Deep"
},
{
"id": 5168,
"logprob": -0.030288696,
"special": false,
"text": " learning"
},
{
"id": 349,
"logprob": -0.1652832,
"special": false,
"text": " is"
},
{
"id": 264,
"logprob": -0.17858887,
"special": false,
"text": " a"
},
{
"id": 19804,
"logprob": -0.81103516,
"special": false,
"text": " subset"
},
{
"id": 302,
"logprob": -0.007183075,
"special": false,
"text": " of"
},
{
"id": 5599,
"logprob": -0.08880615,
"special": false,
"text": " machine"
},
{
"id": 5168,
"logprob": -0.0030612946,
"special": false,
"text": " learning"
}
],
"top_tokens": null
},
"generated_text": "\n\nDeep learning is a subset of machine learning"
}
]

View File

@ -0,0 +1,89 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 3735,
"logprob": -11.0078125,
"text": "Test"
},
{
"id": 2159,
"logprob": -13.59375,
"text": "request"
}
],
"seed": null,
"tokens": [
{
"id": 13,
"logprob": -1.7089844,
"special": false,
"text": "\n"
},
{
"id": 13,
"logprob": -0.68847656,
"special": false,
"text": "\n"
},
{
"id": 28771,
"logprob": -1.9394531,
"special": false,
"text": "#"
},
{
"id": 3735,
"logprob": -2.8808594,
"special": false,
"text": " Test"
},
{
"id": 2159,
"logprob": -0.37280273,
"special": false,
"text": " request"
},
{
"id": 13,
"logprob": -0.26098633,
"special": false,
"text": "\n"
},
{
"id": 13,
"logprob": -0.0017137527,
"special": false,
"text": "\n"
},
{
"id": 1064,
"logprob": -2.2695312,
"special": false,
"text": "##"
},
{
"id": 3735,
"logprob": -1.9238281,
"special": false,
"text": " Test"
},
{
"id": 2159,
"logprob": -0.48828125,
"special": false,
"text": " request"
}
],
"top_tokens": null
},
"generated_text": "\n\n# Test request\n\n## Test request"
}

View File

@ -0,0 +1,89 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 3735,
"logprob": -11.0078125,
"text": "Test"
},
{
"id": 2159,
"logprob": -13.59375,
"text": "request"
}
],
"seed": 0,
"tokens": [
{
"id": 13,
"logprob": -0.34838867,
"special": false,
"text": "\n"
},
{
"id": 13940,
"logprob": -0.38916016,
"special": false,
"text": "``"
},
{
"id": 28832,
"logprob": 0.0,
"special": false,
"text": "`"
},
{
"id": 3371,
"logprob": -1.2529297,
"special": false,
"text": "json"
},
{
"id": 13,
"logprob": 0.0,
"special": false,
"text": "\n"
},
{
"id": 28751,
"logprob": 0.0,
"special": false,
"text": "{"
},
{
"id": 13,
"logprob": 0.0,
"special": false,
"text": "\n"
},
{
"id": 2287,
"logprob": 0.0,
"special": false,
"text": " "
},
{
"id": 345,
"logprob": 0.0,
"special": false,
"text": " \""
},
{
"id": 3134,
"logprob": -0.640625,
"special": false,
"text": "request"
}
],
"top_tokens": null
},
"generated_text": "Test request\n```json\n{\n \"request"
}

View File

@ -0,0 +1,358 @@
[
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 3735,
"logprob": -11.0078125,
"text": "Test"
},
{
"id": 2159,
"logprob": -13.59375,
"text": "request"
}
],
"seed": null,
"tokens": [
{
"id": 13,
"logprob": -1.7089844,
"special": false,
"text": "\n"
},
{
"id": 13,
"logprob": -0.68847656,
"special": false,
"text": "\n"
},
{
"id": 28771,
"logprob": -1.9394531,
"special": false,
"text": "#"
},
{
"id": 3735,
"logprob": -2.8828125,
"special": false,
"text": " Test"
},
{
"id": 2159,
"logprob": -0.37329102,
"special": false,
"text": " request"
},
{
"id": 13,
"logprob": -0.2602539,
"special": false,
"text": "\n"
},
{
"id": 13,
"logprob": -0.0017185211,
"special": false,
"text": "\n"
},
{
"id": 1064,
"logprob": -2.2753906,
"special": false,
"text": "##"
},
{
"id": 3735,
"logprob": -1.9316406,
"special": false,
"text": " Test"
},
{
"id": 2159,
"logprob": -0.48217773,
"special": false,
"text": " request"
}
],
"top_tokens": null
},
"generated_text": "\n\n# Test request\n\n## Test request"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 3735,
"logprob": -11.0078125,
"text": "Test"
},
{
"id": 2159,
"logprob": -13.59375,
"text": "request"
}
],
"seed": null,
"tokens": [
{
"id": 13,
"logprob": -1.7089844,
"special": false,
"text": "\n"
},
{
"id": 13,
"logprob": -0.68847656,
"special": false,
"text": "\n"
},
{
"id": 28771,
"logprob": -1.9394531,
"special": false,
"text": "#"
},
{
"id": 3735,
"logprob": -2.8828125,
"special": false,
"text": " Test"
},
{
"id": 2159,
"logprob": -0.37329102,
"special": false,
"text": " request"
},
{
"id": 13,
"logprob": -0.2602539,
"special": false,
"text": "\n"
},
{
"id": 13,
"logprob": -0.0017185211,
"special": false,
"text": "\n"
},
{
"id": 1064,
"logprob": -2.2753906,
"special": false,
"text": "##"
},
{
"id": 3735,
"logprob": -1.9316406,
"special": false,
"text": " Test"
},
{
"id": 2159,
"logprob": -0.48217773,
"special": false,
"text": " request"
}
],
"top_tokens": null
},
"generated_text": "\n\n# Test request\n\n## Test request"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 3735,
"logprob": -11.0078125,
"text": "Test"
},
{
"id": 2159,
"logprob": -13.59375,
"text": "request"
}
],
"seed": null,
"tokens": [
{
"id": 13,
"logprob": -1.7089844,
"special": false,
"text": "\n"
},
{
"id": 13,
"logprob": -0.68847656,
"special": false,
"text": "\n"
},
{
"id": 28771,
"logprob": -1.9394531,
"special": false,
"text": "#"
},
{
"id": 3735,
"logprob": -2.8828125,
"special": false,
"text": " Test"
},
{
"id": 2159,
"logprob": -0.37329102,
"special": false,
"text": " request"
},
{
"id": 13,
"logprob": -0.2602539,
"special": false,
"text": "\n"
},
{
"id": 13,
"logprob": -0.0017185211,
"special": false,
"text": "\n"
},
{
"id": 1064,
"logprob": -2.2753906,
"special": false,
"text": "##"
},
{
"id": 3735,
"logprob": -1.9316406,
"special": false,
"text": " Test"
},
{
"id": 2159,
"logprob": -0.48217773,
"special": false,
"text": " request"
}
],
"top_tokens": null
},
"generated_text": "\n\n# Test request\n\n## Test request"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 3735,
"logprob": -11.0078125,
"text": "Test"
},
{
"id": 2159,
"logprob": -13.59375,
"text": "request"
}
],
"seed": null,
"tokens": [
{
"id": 13,
"logprob": -1.7089844,
"special": false,
"text": "\n"
},
{
"id": 13,
"logprob": -0.68847656,
"special": false,
"text": "\n"
},
{
"id": 28771,
"logprob": -1.9394531,
"special": false,
"text": "#"
},
{
"id": 3735,
"logprob": -2.8828125,
"special": false,
"text": " Test"
},
{
"id": 2159,
"logprob": -0.37329102,
"special": false,
"text": " request"
},
{
"id": 13,
"logprob": -0.2602539,
"special": false,
"text": "\n"
},
{
"id": 13,
"logprob": -0.0017185211,
"special": false,
"text": "\n"
},
{
"id": 1064,
"logprob": -2.2753906,
"special": false,
"text": "##"
},
{
"id": 3735,
"logprob": -1.9316406,
"special": false,
"text": " Test"
},
{
"id": 2159,
"logprob": -0.48217773,
"special": false,
"text": " request"
}
],
"top_tokens": null
},
"generated_text": "\n\n# Test request\n\n## Test request"
}
]

View File

@ -0,0 +1,109 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1724,
"logprob": null,
"text": "What"
},
{
"id": 338,
"logprob": -0.7133789,
"text": "is"
},
{
"id": 16030,
"logprob": -13.9296875,
"text": "gradient"
},
{
"id": 26815,
"logprob": -0.048919678,
"text": "descent"
},
{
"id": 29973,
"logprob": -3.0078125,
"text": "?"
},
{
"id": 13,
"logprob": -2.8105469,
"text": "\n"
},
{
"id": 13,
"logprob": -0.84521484,
"text": "\n"
}
],
"seed": null,
"tokens": [
{
"id": 25584,
"logprob": -0.017028809,
"special": false,
"text": "Grad"
},
{
"id": 993,
"logprob": -0.0027313232,
"special": false,
"text": "ient"
},
{
"id": 26815,
"logprob": -0.023254395,
"special": false,
"text": " descent"
},
{
"id": 338,
"logprob": -2.0623207e-05,
"special": false,
"text": " is"
},
{
"id": 263,
"logprob": -0.5361328,
"special": false,
"text": " a"
},
{
"id": 937,
"logprob": -0.17578125,
"special": false,
"text": " first"
},
{
"id": 29899,
"logprob": 0.0,
"special": false,
"text": "-"
},
{
"id": 2098,
"logprob": -0.00011539459,
"special": false,
"text": "order"
},
{
"id": 13883,
"logprob": -0.47436523,
"special": false,
"text": " optimization"
},
{
"id": 5687,
"logprob": -0.00027680397,
"special": false,
"text": " algorithm"
}
],
"top_tokens": null
},
"generated_text": "Gradient descent is a first-order optimization algorithm"
}

View File

@ -0,0 +1,99 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 16030,
"logprob": null,
"text": "gradient"
},
{
"id": 26815,
"logprob": -6.4960938,
"text": "descent"
},
{
"id": 29973,
"logprob": -5.1484375,
"text": "?"
},
{
"id": 13,
"logprob": -4.0351562,
"text": "\n"
},
{
"id": 13,
"logprob": -5.2265625,
"text": "\n"
}
],
"seed": 0,
"tokens": [
{
"id": 10994,
"logprob": -1.1542969,
"special": false,
"text": "Hello"
},
{
"id": 29991,
"logprob": 0.0,
"special": false,
"text": "!"
},
{
"id": 739,
"logprob": 0.0,
"special": false,
"text": " It"
},
{
"id": 2444,
"logprob": -0.42260742,
"special": false,
"text": " seems"
},
{
"id": 366,
"logprob": 0.0,
"special": false,
"text": " you"
},
{
"id": 29915,
"logprob": 0.0,
"special": false,
"text": "'"
},
{
"id": 276,
"logprob": -0.9838867,
"special": false,
"text": "re"
},
{
"id": 3211,
"logprob": 0.0,
"special": false,
"text": " address"
},
{
"id": 292,
"logprob": 0.0,
"special": false,
"text": "ing"
},
{
"id": 263,
"logprob": -0.15124512,
"special": false,
"text": " a"
}
],
"top_tokens": null
},
"generated_text": "What is gradient descent?\n\nHello! It seems you're addressing a"
}

View File

@ -0,0 +1,438 @@
[
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1724,
"logprob": null,
"text": "What"
},
{
"id": 338,
"logprob": -0.7133789,
"text": "is"
},
{
"id": 16030,
"logprob": -13.9296875,
"text": "gradient"
},
{
"id": 26815,
"logprob": -0.048919678,
"text": "descent"
},
{
"id": 29973,
"logprob": -3.0078125,
"text": "?"
},
{
"id": 13,
"logprob": -2.8105469,
"text": "\n"
},
{
"id": 13,
"logprob": -0.84521484,
"text": "\n"
}
],
"seed": null,
"tokens": [
{
"id": 25584,
"logprob": -0.017028809,
"special": false,
"text": "Grad"
},
{
"id": 993,
"logprob": -0.0028476715,
"special": false,
"text": "ient"
},
{
"id": 26815,
"logprob": -0.023971558,
"special": false,
"text": " descent"
},
{
"id": 338,
"logprob": -2.0384789e-05,
"special": false,
"text": " is"
},
{
"id": 263,
"logprob": -0.5229492,
"special": false,
"text": " a"
},
{
"id": 937,
"logprob": -0.17602539,
"special": false,
"text": " first"
},
{
"id": 29899,
"logprob": 0.0,
"special": false,
"text": "-"
},
{
"id": 2098,
"logprob": -0.000116467476,
"special": false,
"text": "order"
},
{
"id": 13883,
"logprob": -0.47436523,
"special": false,
"text": " optimization"
},
{
"id": 5687,
"logprob": -0.00027871132,
"special": false,
"text": " algorithm"
}
],
"top_tokens": null
},
"generated_text": "Gradient descent is a first-order optimization algorithm"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1724,
"logprob": null,
"text": "What"
},
{
"id": 338,
"logprob": -0.7128906,
"text": "is"
},
{
"id": 16030,
"logprob": -13.9375,
"text": "gradient"
},
{
"id": 26815,
"logprob": -0.05053711,
"text": "descent"
},
{
"id": 29973,
"logprob": -3.0058594,
"text": "?"
},
{
"id": 13,
"logprob": -2.8242188,
"text": "\n"
},
{
"id": 13,
"logprob": -0.84521484,
"text": "\n"
}
],
"seed": null,
"tokens": [
{
"id": 25584,
"logprob": -0.018859863,
"special": false,
"text": "Grad"
},
{
"id": 993,
"logprob": -0.002822876,
"special": false,
"text": "ient"
},
{
"id": 26815,
"logprob": -0.023254395,
"special": false,
"text": " descent"
},
{
"id": 338,
"logprob": -2.0384789e-05,
"special": false,
"text": " is"
},
{
"id": 263,
"logprob": -0.5229492,
"special": false,
"text": " a"
},
{
"id": 937,
"logprob": -0.17126465,
"special": false,
"text": " first"
},
{
"id": 29899,
"logprob": 0.0,
"special": false,
"text": "-"
},
{
"id": 2098,
"logprob": -0.0001155138,
"special": false,
"text": "order"
},
{
"id": 13883,
"logprob": -0.47436523,
"special": false,
"text": " optimization"
},
{
"id": 5687,
"logprob": -0.00027036667,
"special": false,
"text": " algorithm"
}
],
"top_tokens": null
},
"generated_text": "Gradient descent is a first-order optimization algorithm"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1724,
"logprob": null,
"text": "What"
},
{
"id": 338,
"logprob": -0.71484375,
"text": "is"
},
{
"id": 16030,
"logprob": -13.9375,
"text": "gradient"
},
{
"id": 26815,
"logprob": -0.049346924,
"text": "descent"
},
{
"id": 29973,
"logprob": -3.0078125,
"text": "?"
},
{
"id": 13,
"logprob": -2.8242188,
"text": "\n"
},
{
"id": 13,
"logprob": -0.86328125,
"text": "\n"
}
],
"seed": null,
"tokens": [
{
"id": 25584,
"logprob": -0.017196655,
"special": false,
"text": "Grad"
},
{
"id": 993,
"logprob": -0.0028438568,
"special": false,
"text": "ient"
},
{
"id": 26815,
"logprob": -0.023254395,
"special": false,
"text": " descent"
},
{
"id": 338,
"logprob": -2.026558e-05,
"special": false,
"text": " is"
},
{
"id": 263,
"logprob": -0.5229492,
"special": false,
"text": " a"
},
{
"id": 937,
"logprob": -0.17602539,
"special": false,
"text": " first"
},
{
"id": 29899,
"logprob": 0.0,
"special": false,
"text": "-"
},
{
"id": 2098,
"logprob": -0.00011622906,
"special": false,
"text": "order"
},
{
"id": 13883,
"logprob": -0.48608398,
"special": false,
"text": " optimization"
},
{
"id": 5687,
"logprob": -0.00027894974,
"special": false,
"text": " algorithm"
}
],
"top_tokens": null
},
"generated_text": "Gradient descent is a first-order optimization algorithm"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1724,
"logprob": null,
"text": "What"
},
{
"id": 338,
"logprob": -0.7192383,
"text": "is"
},
{
"id": 16030,
"logprob": -13.9375,
"text": "gradient"
},
{
"id": 26815,
"logprob": -0.050445557,
"text": "descent"
},
{
"id": 29973,
"logprob": -3.0078125,
"text": "?"
},
{
"id": 13,
"logprob": -2.8242188,
"text": "\n"
},
{
"id": 13,
"logprob": -0.8276367,
"text": "\n"
}
],
"seed": null,
"tokens": [
{
"id": 25584,
"logprob": -0.01727295,
"special": false,
"text": "Grad"
},
{
"id": 993,
"logprob": -0.0027542114,
"special": false,
"text": "ient"
},
{
"id": 26815,
"logprob": -0.023254395,
"special": false,
"text": " descent"
},
{
"id": 338,
"logprob": -2.0384789e-05,
"special": false,
"text": " is"
},
{
"id": 263,
"logprob": -0.5229492,
"special": false,
"text": " a"
},
{
"id": 937,
"logprob": -0.17126465,
"special": false,
"text": " first"
},
{
"id": 29899,
"logprob": 0.0,
"special": false,
"text": "-"
},
{
"id": 2098,
"logprob": -0.00011301041,
"special": false,
"text": "order"
},
{
"id": 13883,
"logprob": -0.48608398,
"special": false,
"text": " optimization"
},
{
"id": 5687,
"logprob": -0.00027894974,
"special": false,
"text": " algorithm"
}
],
"top_tokens": null
},
"generated_text": "Gradient descent is a first-order optimization algorithm"
}
]

View File

@ -0,0 +1,106 @@
[
{
"choices": [
{
"finish_reason": "length",
"index": 0,
"logprobs": null,
"message": {
"content": "In a bustling city, a chicken named Cluck",
"name": null,
"role": "assistant",
"tool_calls": null
},
"usage": null
}
],
"created": 1727773835,
"id": "",
"model": "meta-llama/Llama-3.2-11B-Vision-Instruct",
"object": "chat.completion",
"system_fingerprint": "2.3.1-dev0-native",
"usage": {
"completion_tokens": 10,
"prompt_tokens": 50,
"total_tokens": 60
}
},
{
"choices": [
{
"finish_reason": "length",
"index": 0,
"logprobs": null,
"message": {
"content": "In a world where even chickens could dream big,",
"name": null,
"role": "assistant",
"tool_calls": null
},
"usage": null
}
],
"created": 1727773835,
"id": "",
"model": "meta-llama/Llama-3.2-11B-Vision-Instruct",
"object": "chat.completion",
"system_fingerprint": "2.3.1-dev0-native",
"usage": {
"completion_tokens": 10,
"prompt_tokens": 50,
"total_tokens": 60
}
},
{
"choices": [
{
"finish_reason": "length",
"index": 0,
"logprobs": null,
"message": {
"content": "In a world where even chickens could dream big,",
"name": null,
"role": "assistant",
"tool_calls": null
},
"usage": null
}
],
"created": 1727773835,
"id": "",
"model": "meta-llama/Llama-3.2-11B-Vision-Instruct",
"object": "chat.completion",
"system_fingerprint": "2.3.1-dev0-native",
"usage": {
"completion_tokens": 10,
"prompt_tokens": 50,
"total_tokens": 60
}
},
{
"choices": [
{
"finish_reason": "length",
"index": 0,
"logprobs": null,
"message": {
"content": "In a world where even chickens could dream big,",
"name": null,
"role": "assistant",
"tool_calls": null
},
"usage": null
}
],
"created": 1727773835,
"id": "",
"model": "meta-llama/Llama-3.2-11B-Vision-Instruct",
"object": "chat.completion",
"system_fingerprint": "2.3.1-dev0-native",
"usage": {
"completion_tokens": 10,
"prompt_tokens": 50,
"total_tokens": 60
}
}
]

View File

@ -0,0 +1,26 @@
{
"choices": [
{
"finish_reason": "length",
"index": 0,
"logprobs": null,
"message": {
"content": "In a bustling city, a chicken named Cluck",
"name": null,
"role": "assistant",
"tool_calls": null
},
"usage": null
}
],
"created": 1727556016,
"id": "",
"model": "meta-llama/Llama-3.2-11B-Vision-Instruct",
"object": "chat.completion",
"system_fingerprint": "2.3.1-dev0-native",
"usage": {
"completion_tokens": 10,
"prompt_tokens": 50,
"total_tokens": 60
}
}

View File

@ -1,38 +1,26 @@
{
"choices": [
{
"finish_reason": "eos_token",
"finish_reason": "stop",
"index": 0,
"logprobs": null,
"message": {
"content": null,
"content": "I am an AI assistant",
"name": null,
"role": "assistant",
"tool_calls": [
{
"function": {
"arguments": {
"error": "Cannot get current weather forecast from specified location and temperature unit. Please try again with different options."
},
"description": null,
"name": "notify_error"
},
"id": 0,
"type": "function"
}
]
"tool_calls": null
},
"usage": null
}
],
"created": 1712852597,
"created": 1728497062,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "1.4.5-native",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion",
"system_fingerprint": "2.3.2-dev0-native",
"usage": {
"completion_tokens": 39,
"prompt_tokens": 496,
"total_tokens": 535
"completion_tokens": 23,
"prompt_tokens": 604,
"total_tokens": 627
}
}

View File

@ -0,0 +1,20 @@
{
"choices": [
{
"delta": {
"content": " assistant",
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1728497531,
"id": "",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
"system_fingerprint": "2.3.2-dev0-native",
"usage": null
}

View File

@ -0,0 +1,20 @@
{
"choices": [
{
"delta": {
"content": " fans",
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1728497461,
"id": "",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
"system_fingerprint": "2.3.2-dev0-native",
"usage": null
}

View File

@ -16,7 +16,7 @@ async def flash_gemma(flash_gemma_handle):
@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_gemma(flash_gemma, response_snapshot):
async def test_flash_gemma_simple(flash_gemma, response_snapshot):
response = await flash_gemma.generate(
"Test request", max_new_tokens=10, decoder_input_details=True
)

View File

@ -15,7 +15,7 @@ async def flash_llama(flash_llama_handle):
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama(flash_llama, response_snapshot):
async def test_flash_llama_simple(flash_llama, response_snapshot):
response = await flash_llama.generate(
"Test request", max_new_tokens=10, decoder_input_details=True
)

View File

@ -0,0 +1,77 @@
import pytest
@pytest.fixture(scope="module")
def flash_llama_fp8_kv_cache_handle(launcher):
with launcher(
"meta-llama/Meta-Llama-3-8B", num_shard=2, kv_cache_dtype="fp8_e5m2"
) as handle:
yield handle
@pytest.fixture(scope="module")
async def flash_llama_fp8_kv_cache(flash_llama_fp8_kv_cache_handle):
await flash_llama_fp8_kv_cache_handle.health(300)
return flash_llama_fp8_kv_cache_handle.client
@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_fp8_kv_cache(flash_llama_fp8_kv_cache, response_snapshot):
response = await flash_llama_fp8_kv_cache.generate(
"What is deep learning?", max_new_tokens=10, decoder_input_details=True
)
assert (
response.generated_text
== " Deep learning is a subset of machine learning that is"
)
assert response.details.generated_tokens == 10
assert response == response_snapshot
@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_fp8_kv_cache_all_params(
flash_llama_fp8_kv_cache, response_snapshot
):
response = await flash_llama_fp8_kv_cache.generate(
"What is deep learning?",
max_new_tokens=10,
repetition_penalty=1.2,
return_full_text=True,
stop_sequences=["test"],
temperature=0.5,
top_p=0.9,
top_k=10,
truncate=5,
typical_p=0.9,
watermark=True,
decoder_input_details=True,
seed=0,
)
assert response == response_snapshot
@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_fp8_kv_cache_load(
flash_llama_fp8_kv_cache, generate_load, response_snapshot
):
responses = await generate_load(
flash_llama_fp8_kv_cache, "What is deep learning?", max_new_tokens=10, n=4
)
assert len(responses) == 4
assert (
responses[0].generated_text
== " Deep learning is a subset of machine learning that is"
)
assert all(
[r.generated_text == responses[0].generated_text for r in responses]
), f"Different messages : {[r.generated_text for r in responses]}"
assert responses == response_snapshot
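
As a rough motivation for the fp8_e5m2 option exercised above: storing keys and values in one byte instead of two halves KV-cache memory. A back-of-the-envelope sketch, assuming Meta-Llama-3-8B's shape (32 layers, 8 KV heads, head dim 128; these figures are assumptions, not taken from this diff):

layers, kv_heads, head_dim = 32, 8, 128

def kv_bytes_per_token(dtype_bytes):
    # Key + value tensors across all layers.
    return 2 * layers * kv_heads * head_dim * dtype_bytes

print(kv_bytes_per_token(2))  # fp16/bf16: 131072 bytes (128 KiB) per token
print(kv_bytes_per_token(1))  # fp8_e5m2:   65536 bytes  (64 KiB) per token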

View File

@ -0,0 +1,73 @@
import pytest
@pytest.fixture(scope="module")
def flash_mixtral_awq_handle(launcher):
with launcher("casperhansen/mixtral-instruct-awq", num_shard=2) as handle:
yield handle
@pytest.fixture(scope="module")
async def flash_mixtral_awq(flash_mixtral_awq_handle):
await flash_mixtral_awq_handle.health(300)
return flash_mixtral_awq_handle.client
@pytest.mark.asyncio
async def test_flash_mixtral_awq(flash_mixtral_awq, response_snapshot):
response = await flash_mixtral_awq.generate(
"What is deep learning?", max_new_tokens=10, decoder_input_details=True
)
assert response.details.generated_tokens == 10
assert (
response.generated_text == "\n\nDeep learning is a subset of machine learning"
)
assert response == response_snapshot
@pytest.mark.asyncio
async def test_flash_mixtral_awq_all_params(flash_mixtral_awq, response_snapshot):
response = await flash_mixtral_awq.generate(
"What is deep learning?",
max_new_tokens=10,
repetition_penalty=1.2,
return_full_text=True,
stop_sequences=["test"],
temperature=0.5,
top_p=0.9,
top_k=10,
truncate=5,
typical_p=0.9,
watermark=True,
decoder_input_details=True,
seed=0,
)
assert response.details.generated_tokens == 10
assert (
response.generated_text
== "What is deep learning?\nDeep Learning is a subset of Machine Learning,"
)
assert response == response_snapshot
@pytest.mark.asyncio
async def test_flash_mixtral_awq_load(
flash_mixtral_awq, generate_load, response_snapshot
):
responses = await generate_load(
flash_mixtral_awq, "What is deep learning?", max_new_tokens=10, n=4
)
assert len(responses) == 4
assert responses[0].details.generated_tokens == 10
assert (
responses[0].generated_text
== "\n\nDeep learning is a subset of machine learning"
)
assert all(
[r.generated_text == responses[0].generated_text for r in responses]
), f"{[r.generated_text for r in responses]}"
assert responses == response_snapshot

View File

@ -0,0 +1,60 @@
import pytest
@pytest.fixture(scope="module")
def flash_mixtral_gptq_handle(launcher):
with launcher("TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ", num_shard=2) as handle:
yield handle
@pytest.fixture(scope="module")
async def flash_mixtral_gptq(flash_mixtral_gptq_handle):
await flash_mixtral_gptq_handle.health(300)
return flash_mixtral_gptq_handle.client
@pytest.mark.asyncio
async def test_flash_mixtral_gptq(flash_mixtral_gptq, response_snapshot):
response = await flash_mixtral_gptq.generate(
"Test request", max_new_tokens=10, decoder_input_details=True
)
assert response == response_snapshot
@pytest.mark.asyncio
async def test_flash_mixtral_gptq_all_params(flash_mixtral_gptq, response_snapshot):
response = await flash_mixtral_gptq.generate(
"Test request",
max_new_tokens=10,
repetition_penalty=1.2,
return_full_text=True,
stop_sequences=["test"],
temperature=0.5,
top_p=0.9,
top_k=10,
truncate=5,
typical_p=0.9,
watermark=True,
decoder_input_details=True,
seed=0,
)
assert response.details.generated_tokens == 10
assert response == response_snapshot
@pytest.mark.asyncio
async def test_flash_mixtral_gptq_load(
flash_mixtral_gptq, generate_load, response_snapshot
):
responses = await generate_load(
flash_mixtral_gptq, "Test request", max_new_tokens=10, n=4
)
assert len(responses) == 4
assert all(
[r.generated_text == responses[0].generated_text for r in responses]
), f"{[r.generated_text for r in responses]}"
assert responses == response_snapshot

View File

@ -0,0 +1,75 @@
import pytest
@pytest.fixture(scope="module")
def flash_phi35_moe_handle(launcher):
with launcher(
"microsoft/Phi-3.5-MoE-instruct",
num_shard=4,
) as handle:
yield handle
@pytest.fixture(scope="module")
async def flash_phi35_moe(flash_phi35_moe_handle):
await flash_phi35_moe_handle.health(300)
return flash_phi35_moe_handle.client
@pytest.mark.asyncio
async def test_flash_phi35_moe(flash_phi35_moe, response_snapshot):
response = await flash_phi35_moe.generate(
"What is gradient descent?\n\n", max_new_tokens=10, decoder_input_details=True
)
assert response.details.generated_tokens == 10
assert (
response.generated_text
== "Gradient descent is a first-order optimization algorithm"
)
assert response == response_snapshot
@pytest.mark.asyncio
async def test_flash_phi35_moe_all_params(flash_phi35_moe, response_snapshot):
response = await flash_phi35_moe.generate(
"What is gradient descent?\n\n",
max_new_tokens=10,
repetition_penalty=1.2,
return_full_text=True,
stop_sequences=["test"],
temperature=0.5,
top_p=0.9,
top_k=10,
truncate=5,
typical_p=0.9,
watermark=True,
decoder_input_details=True,
seed=0,
)
assert response.details.generated_tokens == 10
assert (
response.generated_text
== "What is gradient descent?\n\nHello! It seems you're addressing a"
)
assert response == response_snapshot
@pytest.mark.asyncio
async def test_flash_phi35_moe_load(flash_phi35_moe, generate_load, response_snapshot):
responses = await generate_load(
flash_phi35_moe, "What is gradient descent?\n\n", max_new_tokens=10, n=4
)
assert len(responses) == 4
assert responses[0].details.generated_tokens == 10
assert (
responses[0].generated_text
== "Gradient descent is a first-order optimization algorithm"
)
assert all(
[r.generated_text == responses[0].generated_text for r in responses]
), f"{[r.generated_text for r in responses]}"
assert responses == response_snapshot

View File

@ -0,0 +1,105 @@
import pytest
import base64
import asyncio
@pytest.fixture(scope="module")
def mllama_handle(launcher):
with launcher("meta-llama/Llama-3.2-11B-Vision-Instruct", num_shard=2) as handle:
yield handle
@pytest.fixture(scope="module")
async def mllama(mllama_handle):
await mllama_handle.health(300)
return mllama_handle.client
# TODO: fix the server parser to count inline image tokens correctly
def get_chicken():
with open("integration-tests/images/chicken_on_money.png", "rb") as image_file:
encoded_string = base64.b64encode(image_file.read())
return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
def get_cow_beach():
with open("integration-tests/images/cow_beach.png", "rb") as image_file:
encoded_string = base64.b64encode(image_file.read())
return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
@pytest.mark.asyncio
async def test_mllama_simple(mllama, response_snapshot):
# chicken = get_chicken()
response = await mllama.chat(
max_tokens=10,
temperature=0.0,
messages=[
{
"role": "user",
"content": [
{
"type": "text",
"text": "Can you tell me a very short story based on the image?",
},
{
"type": "image_url",
"image_url": {
"url": "https://raw.githubusercontent.com/huggingface/text-generation-inference/main/integration-tests/images/chicken_on_money.png"
},
},
],
},
],
)
assert response.usage == {
"completion_tokens": 10,
"prompt_tokens": 50,
"total_tokens": 60,
}
assert (
response.choices[0].message.content
== "In a bustling city, a chicken named Cluck"
)
assert response == response_snapshot
@pytest.mark.release
@pytest.mark.asyncio
async def test_mllama_load(mllama, generate_load, response_snapshot):
futures = [
mllama.chat(
max_tokens=10,
temperature=0.0,
messages=[
{
"role": "user",
"content": [
{
"type": "text",
"text": "Can you tell me a very short story based on the image?",
},
{
"type": "image_url",
"image_url": {
"url": "https://raw.githubusercontent.com/huggingface/text-generation-inference/main/integration-tests/images/chicken_on_money.png"
},
},
],
},
],
)
for i in range(4)
]
responses = await asyncio.gather(*futures)
generated_texts = [response.choices[0].message.content for response in responses]
assert generated_texts[0] == "In a bustling city, a chicken named Cluck"
assert len(generated_texts) == 4
assert all(
[text == generated_texts[0] for text in generated_texts]
)
assert responses == response_snapshot

View File

@ -4,7 +4,9 @@ import pytest
@pytest.fixture(scope="module")
def flash_llama_grammar_tools_handle(launcher):
with launcher(
"TinyLlama/TinyLlama-1.1B-Chat-v1.0", num_shard=2, disable_grammar_support=False
"meta-llama/Meta-Llama-3.1-8B-Instruct",
num_shard=2,
disable_grammar_support=False,
) as handle:
yield handle
@ -205,11 +207,20 @@ async def test_flash_llama_grammar_tools_stream(
)
count = 0
tool_calls_generated = ""
last_response = None
async for response in responses:
count += 1
tool_calls_generated += response.choices[0].delta.tool_calls.function.arguments
last_response = response
assert response.choices[0].delta.content is None
assert count == 48
assert response == response_snapshot
assert (
tool_calls_generated
== '{"function": {"_name": "get_current_weather", "format": "celsius", "location": "Paris, France"}}<|eot_id|>'
)
assert count == 28
assert last_response == response_snapshot
@pytest.mark.asyncio
@ -225,18 +236,94 @@ async def test_flash_llama_grammar_tools_insufficient_information(
messages=[
{
"role": "system",
"content": "STRICTLY ONLY RESPOND IF THE USER ASKS A WEATHER RELATED QUESTION",
"content": "You're a helpful assistant! Answer the users question best you can.",
},
{
"role": "user",
"content": "Who are you?",
},
],
stream=False,
)
assert responses.choices[0].message.tool_calls is None
assert responses.choices[0].message.content == "I am an AI assistant"
assert responses == response_snapshot
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_tools_insufficient_information_stream(
flash_llama_grammar_tools, response_snapshot
):
responses = await flash_llama_grammar_tools.chat(
max_tokens=100,
seed=24,
tools=tools,
tool_choice="auto",
messages=[
{
"role": "system",
"content": "You're a helpful assistant! Answer the users question best you can.",
},
{
"role": "user",
"content": "Who are you?",
},
],
stream=True,
)
count = 0
content_generated = ""
last_response = None
async for response in responses:
count += 1
content_generated += response.choices[0].delta.content
last_response = response
assert response.choices[0].delta.tool_calls is None
assert count == 5
assert content_generated == "I am an AI assistant"
assert last_response == response_snapshot
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_tools_sea_creatures_stream(
flash_llama_grammar_tools, response_snapshot
):
responses = await flash_llama_grammar_tools.chat(
max_tokens=100,
seed=24,
tools=tools,
tool_choice="auto",
messages=[
{
"role": "system",
"content": "You're a helpful assistant! Answer the users question best you can. If the question is not answerable by the tools, just generate a response.",
},
{
"role": "user",
"content": "Tell me a story about 3 sea creatures",
},
],
stream=False,
stream=True,
)
assert responses.choices[0].message.content is None
count = 0
content_generated = ""
last_response = None
async for response in responses:
count += 1
content_generated += response.choices[0].delta.content
last_response = response
assert response.choices[0].delta.tool_calls is None
assert count == 62
assert (
responses.choices[0].message.tool_calls[0]["function"]["name"] == "notify_error"
content_generated
== "Once upon a time, in the ocean, there lived three sea creatures. There was a wise old octopus named Bob, a mischievous seagull named Sam, and a gentle sea turtle named Luna. They all lived together in a beautiful coral reef, surrounded by colorful fish and swaying sea fans"
)
assert responses == response_snapshot
assert last_response == response_snapshot

View File

@ -13,3 +13,6 @@ pytest = "^7.4.0"
pytest-asyncio = "^0.21.1"
docker = "^7"
numpy = "^1.20"
[tool.isort]
profile = "black"

View File

@ -18,6 +18,7 @@ serde_json = "1.0.107"
thiserror = "1.0.59"
tracing = "0.1.37"
tracing-subscriber = { version = "0.3.17", features = ["json", "env-filter"] }
regex = "1.11.0"
[dev-dependencies]
float_eq = "1.0.1"

View File

@ -1,9 +1,4 @@
use std::sync::LazyLock;
pub static COMPUTE_CAPABILITY: LazyLock<Option<(usize, usize)>> =
LazyLock::new(get_cuda_capability);
fn get_cuda_capability() -> Option<(usize, usize)> {
pub fn get_cuda_capability() -> Option<(usize, usize)> {
use pyo3::prelude::*;
let py_get_capability = |py: Python| -> PyResult<(isize, isize)> {

View File

@ -5,6 +5,7 @@ use hf_hub::{
};
use nix::sys::signal::{self, Signal};
use nix::unistd::Pid;
use regex::Regex;
use serde::Deserialize;
use std::env;
use std::ffi::OsString;
@ -66,7 +67,7 @@ fn get_config(
}
fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) -> (String, String) {
let compute_capability = *gpu::COMPUTE_CAPABILITY;
let compute_capability = gpu::get_cuda_capability();
let mut prefix_caching: Option<String> = std::env::var("USE_PREFIX_CACHING").ok();
let mut attention: Option<String> = std::env::var("ATTENTION").ok();
if let Some(config) = config {
@ -300,6 +301,22 @@ impl std::fmt::Display for Dtype {
}
}
#[derive(Clone, Copy, Debug, ValueEnum)]
enum KVCacheDtype {
#[clap(name = "fp8_e5m2")]
Fp8e5m2,
}
impl std::fmt::Display for KVCacheDtype {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
KVCacheDtype::Fp8e5m2 => {
write!(f, "fp8_e5m2")
}
}
}
}
#[derive(Clone, Copy, Debug, ValueEnum)]
enum RopeScaling {
Linear,
@ -401,6 +418,12 @@ struct Args {
#[clap(long, env, value_enum)]
dtype: Option<Dtype>,
/// Specify the dtype for the key-value cache. When this option is not provided,
/// the dtype of the model is used (typically `float16` or `bfloat16`). Currently
/// the only supported value is `fp8_e5m2` on CUDA.
#[clap(long, env, value_enum)]
kv_cache_dtype: Option<KVCacheDtype>,
/// Whether you want to execute hub modelling code. Explicitly passing a `revision` is
/// encouraged when loading a model with custom code to ensure no malicious code has been
/// contributed in a newer revision.
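
For illustration, the fp8 KV-cache integration test above drives this flag through the Python launcher fixture; a hypothetical minimal use (the fixture forwards the value to each shard as --kv-cache-dtype fp8_e5m2 via shard_manager below):

# Sketch: `launcher` is the integration-test conftest fixture, not a public API.
with launcher(
    "meta-llama/Meta-Llama-3-8B",
    num_shard=2,
    kv_cache_dtype="fp8_e5m2",  # becomes --kv-cache-dtype fp8_e5m2 on the shard
) as handle:
    client = handle.client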
@ -669,6 +692,7 @@ fn shard_manager(
quantize: Option<Quantization>,
speculate: Option<usize>,
dtype: Option<Dtype>,
kv_cache_dtype: Option<KVCacheDtype>,
trust_remote_code: bool,
uds_path: String,
rank: usize,
@ -742,6 +766,11 @@ fn shard_manager(
shard_args.push(dtype.to_string())
}
if let Some(kv_cache_dtype) = kv_cache_dtype {
shard_args.push("--kv-cache-dtype".to_string());
shard_args.push(kv_cache_dtype.to_string())
}
// Model optional revision
if let Some(revision) = revision {
shard_args.push("--revision".to_string());
@ -915,6 +944,7 @@ fn shard_manager(
}
});
// We read stdin in another thread as it seems that lines() can block in some cases
if LevelFilter::current() >= tracing::Level::DEBUG {
thread::spawn(move || {
let mut stdin = io::stdin(); // We get `Stdin` here.
loop {
@ -922,12 +952,11 @@ fn shard_manager(
if let Ok(n) = stdin.read(&mut buffer) {
if n > 0 {
let _ = pstdin.write_all(&buffer[..n]);
} else {
break;
}
}
}
});
}
let mut ready = false;
let start_time = Instant::now();
@ -1302,6 +1331,7 @@ fn spawn_shards(
let otlp_service_name = args.otlp_service_name.clone();
let speculate = args.speculate;
let dtype = args.dtype;
let kv_cache_dtype = args.kv_cache_dtype;
let trust_remote_code = args.trust_remote_code;
let master_port = args.master_port;
let disable_custom_kernels = args.disable_custom_kernels;
@ -1320,6 +1350,7 @@ fn spawn_shards(
quantize,
speculate,
dtype,
kv_cache_dtype,
trust_remote_code,
uds_path,
rank,
@ -1812,14 +1843,37 @@ fn main() -> Result<(), LauncherError> {
if adapter.contains('=') {
continue;
}
let adapter = adapter.trim();
// check if adapter has more than 1 '@'
if adapter.matches('@').count() > 1 {
return Err(LauncherError::ArgumentValidation(format!(
"Invalid LoRA adapter format: {}",
adapter
)));
}
// capture adapter_id, path, and revision in the format adapter_id=path@revision
let re = Regex::new(r"^([^=@]+)(?:=([^@]+))?(?:@(.+))?$").unwrap();
if let Some(caps) = re.captures(adapter) {
let adapter_id = caps.get(1).map_or("", |m| m.as_str());
let revision = caps.get(3).map(|m| m.as_str());
download_convert_model(
adapter,
None,
adapter_id,
revision,
args.trust_remote_code,
args.huggingface_hub_cache.as_deref(),
args.weights_cache_override.as_deref(),
running.clone(),
)?;
} else {
return Err(LauncherError::ArgumentValidation(format!(
"Invalid LoRA adapter format: {}",
adapter
)));
}
}
}
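
The accepted adapter syntax is therefore adapter_id, adapter_id=path, or adapter_id=path@revision. A quick Python check of the same pattern (illustrative only; the launcher itself uses the Rust regex crate, which behaves identically here):

import re

ADAPTER_RE = re.compile(r"^([^=@]+)(?:=([^@]+))?(?:@(.+))?$")

for spec in ["myadapter", "myadapter=/data/adapter", "myadapter=/data/adapter@main"]:
    adapter_id, path, revision = ADAPTER_RE.match(spec).groups()
    print(adapter_id, path, revision)
# -> ('myadapter', None, None)
# -> ('myadapter', '/data/adapter', None)
# -> ('myadapter', '/data/adapter', 'main')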

nix/docker.nix (new file, 23 lines)
View File

@ -0,0 +1,23 @@
{
dockerTools,
cacert,
text-generation-inference,
stream ? false,
}:
let
build = if stream then dockerTools.streamLayeredImage else dockerTools.buildLayeredImage;
in
build {
name = "tgi-docker";
tag = "latest";
config = {
EntryPoint = [ "${text-generation-inference}/bin/text-generation-inference" ];
Env = [
"HF_HOME=/data"
"PORT=80"
];
};
contents = [ cacert ];
}

View File

@ -1,5 +1,7 @@
{
mkShell,
black,
isort,
openssl,
pkg-config,
protobuf,
@ -14,6 +16,8 @@
mkShell {
buildInputs =
[
black
isort
openssl.dev
pkg-config
(rust-bin.stable.latest.default.override {

nix/overlay.nix (new file, 41 lines)
View File

@ -0,0 +1,41 @@
final: prev: {
# You can use this overlay to temporarily override packages for
# development. For permanent overrides, it's better to do this in
# our package flake:
#
# https://github.com/huggingface/text-generation-inference-nix
#
# Note that overriding packages that are in the transitive closure
# of many other packages (e.g. transformers) will require a large
# rebuild.
pythonPackagesExtensions = prev.pythonPackagesExtensions ++ [
(
python-self: python-super: with python-self; {
# Python package override example:
# transformers = python-super.transformers.overrideAttrs (
# _: _: {
# src = final.fetchFromGitHub {
# owner = "huggingface";
# repo = "transformers";
# rev = "2bd4d5897dc73e8b172832070a6f9e567a0df017";
# hash = "sha256-JOIpKH9ssDEfI2Tf15e0iPKtThJwQ9GxMvRAnm+M2Pg=";
# };
# }
# );
}
)
];
# Non-python package override example:
#
# ripgrep = prev.ripgrep.overrideAttrs (
# _: _: {
# src = final.fetchFromGitHub {
# owner = "BurntSushi";
# repo = "ripgrep";
# rev = "79cbe89deb1151e703f4d91b19af9cdcc128b765";
# hash = "sha256-JPTM2KNmGMb+/jOfK3X7OM1wnN+3TU35SJOIcqmp3mg=";
# };
# });
}

View File

@ -146,6 +146,7 @@ pub enum Config {
ClipVisionModel(ClipVisionModel),
Mistral,
Idefics,
Mllama,
Idefics2(Idefics2),
Ssm,
GptBigcode,
@ -159,6 +160,7 @@ pub enum Config {
#[serde(rename = "phi-msft")]
PhiMsft,
Phi3,
PhiMoe,
Llama,
Baichuan,
Paligemma(Paligemma),

View File

@ -29,7 +29,7 @@ impl ChatTemplate {
env.set_unknown_method_callback(pycompat::unknown_method_callback);
let template_str = template.into_boxed_str();
env.add_function("raise_exception", raise_exception);
tracing::debug!("Loading template: {:#?}", template_str);
tracing::debug!("Loading template: {}", template_str);
// leaking env and template_str as read-only, static resources for performance.
let template = Box::leak(env)

View File

@ -355,6 +355,8 @@ pub enum InferError {
MissingTemplateVariable(String),
#[error("Tool error: {0}")]
ToolError(String),
#[error("Stream event serialization error")]
StreamSerializationError(String),
}
impl InferError {
@ -368,6 +370,7 @@ impl InferError {
InferError::TemplateError(_) => "template_error",
InferError::MissingTemplateVariable(_) => "missing_template_variable",
InferError::ToolError(_) => "tool_error",
InferError::StreamSerializationError(_) => "stream_serialization_error",
}
}
}

View File

@ -31,32 +31,29 @@ impl ToolGrammar {
let mut tools = tools.clone();
// add the notify_error function to the tools
let notify_error = Tool {
// add the no_tool function to the tools
let no_tool = Tool {
r#type: "function".to_string(),
function: FunctionDefinition {
name: "notify_error".to_string(),
description: Some("Notify an error or issue".to_string()),
name: "no_tool".to_string(),
description: Some("Open ened response with no specific tool selected".to_string()),
arguments: json!({
"type": "object",
"properties": {
"error": {
"content": {
"type": "string",
"description": "The error or issue to notify"
"description": "The response content",
}
},
"required": ["error"]
"required": ["content"]
}),
},
};
tools.push(notify_error);
tools.push(no_tool);
// if tools are provided and no tool_choice we default to the OneOf
let tools_to_use = match tool_choice {
ToolType::FunctionName(name) => {
vec![Self::find_tool_by_name(&tools, &name)?]
}
ToolType::Function { function } => {
ToolType::Function(function) => {
vec![Self::find_tool_by_name(&tools, &function.name)?]
}
ToolType::OneOf => tools.clone(),

View File

@ -957,12 +957,18 @@ pub fn default_tool_prompt() -> String {
}
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize, ToSchema)]
#[serde(untagged)]
#[schema(example = "auto")]
/// Controls which (if any) tool is called by the model.
pub enum ToolType {
/// Means the model can pick between generating a message or calling one or more tools.
#[schema(rename = "auto")]
OneOf,
FunctionName(String),
Function { function: FunctionName },
/// Means the model will not call any tool and instead generates a message.
#[schema(rename = "none")]
NoTool,
/// Forces the model to call a specific tool.
#[schema(rename = "function")]
Function(FunctionName),
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, ToSchema)]
@ -977,6 +983,7 @@ pub struct ToolChoice(pub Option<ToolType>);
#[derive(Deserialize)]
#[serde(untagged)]
enum ToolTypeDeserializer {
Null,
String(String),
ToolType(ToolType),
}
@ -984,10 +991,11 @@ enum ToolTypeDeserializer {
impl From<ToolTypeDeserializer> for ToolChoice {
fn from(value: ToolTypeDeserializer) -> Self {
match value {
ToolTypeDeserializer::Null => ToolChoice(None),
ToolTypeDeserializer::String(s) => match s.as_str() {
"none" => ToolChoice(Some(ToolType::NoTool)),
"auto" => ToolChoice(Some(ToolType::OneOf)),
_ => ToolChoice(Some(ToolType::FunctionName(s))),
_ => ToolChoice(Some(ToolType::Function(FunctionName { name: s }))),
},
ToolTypeDeserializer::ToolType(tool_type) => ToolChoice(Some(tool_type)),
}
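
Concretely, the deserializer maps the common OpenAI-style payloads as follows (request fragments shown as Python dicts; the shapes follow the enum and match arms above):

# tool_choice omitted or null  -> ToolChoice(None)            (server default)
# "auto"                       -> ToolType::OneOf
# "none"                       -> ToolType::NoTool
# any other string             -> ToolType::Function by name
payloads = [
    {"tool_choice": None},
    {"tool_choice": "auto"},
    {"tool_choice": "none"},
    {"tool_choice": "get_current_weather"},
]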

View File

@ -42,6 +42,7 @@ use hf_hub::{Cache, Repo, RepoType};
use http::header::AUTHORIZATION;
use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
use pyo3::types::IntoPyDict;
use regex::Regex;
use serde_json::Value;
use std::convert::Infallible;
use std::fs::File;
@ -452,12 +453,20 @@ async fn generate_stream(
Sse<impl Stream<Item = Result<Event, Infallible>>>,
) {
let span = tracing::Span::current();
let on_message_callback = |stream_token: StreamResponse| {
let event = Event::default();
event.json_data(stream_token).unwrap()
};
let (headers, response_stream) =
generate_stream_internal(infer, compute_type, Json(req), on_message_callback, span).await;
generate_stream_internal(infer, compute_type, Json(req), span).await;
let response_stream = async_stream::stream! {
let mut response_stream = Box::pin(response_stream);
while let Some(raw_event) = response_stream.next().await {
yield Ok(raw_event.map_or_else(Event::from, |token| {
Event::default()
.json_data(token)
.unwrap_or_else(|e| InferError::StreamSerializationError(e.to_string()).into())
}));
}
};
let sse = Sse::new(response_stream).keep_alive(KeepAlive::default());
(headers, sse)
}
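
With this refactor, `generate_stream_internal` yields plain `Result<StreamResponse, InferError>` items and the SSE encoding happens at the edge, so serialization failures surface as `stream_serialization_error` events instead of panics inside the stream. A client-side sketch of consuming the endpoint (host, port, and prompt are placeholders; only the standard streaming API of `requests` is used):

import json
import requests

# Sketch: consume the /generate_stream SSE endpoint of a local TGI instance.
with requests.post(
    "http://localhost:3000/generate_stream",
    json={"inputs": "What is deep learning?", "parameters": {"max_new_tokens": 20}},
    stream=True,
) as resp:
    for line in resp.iter_lines():
        if not line or not line.startswith(b"data:"):
            continue  # skip keep-alive comments and blank separators
        payload = json.loads(line[len(b"data:"):].strip())
        if "error" in payload:
            # validation errors and StreamSerializationError arrive here
            print("stream error:", payload["error"])
            break
        print(payload["token"]["text"], end="", flush=True)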
@ -466,9 +475,11 @@ async fn generate_stream_internal(
infer: Infer,
ComputeType(compute_type): ComputeType,
Json(req): Json<GenerateRequest>,
on_message_callback: impl Fn(StreamResponse) -> Event,
span: tracing::Span,
) -> (HeaderMap, impl Stream<Item = Result<Event, Infallible>>) {
) -> (
HeaderMap,
impl Stream<Item = Result<StreamResponse, InferError>>,
) {
let start_time = Instant::now();
metrics::counter!("tgi_request_count").increment(1);
@ -500,12 +511,12 @@ async fn generate_stream_internal(
let err = InferError::from(ValidationError::BestOfStream);
metrics::counter!("tgi_request_failure", "err" => "validation").increment(1);
tracing::error!("{err}");
yield Ok(Event::from(err));
yield Err(err);
} else if req.parameters.decoder_input_details {
let err = InferError::from(ValidationError::PrefillDetailsStream);
metrics::counter!("tgi_request_failure", "err" => "validation").increment(1);
tracing::error!("{err}");
yield Ok(Event::from(err));
yield Err(err);
} else {
match infer.generate_stream(req).instrument(info_span!(parent: &span, "async_stream")).await {
// Keep permit as long as generate_stream lives
@ -535,8 +546,7 @@ async fn generate_stream_internal(
generated_text: None,
details: None,
};
let event = on_message_callback(stream_token);
yield Ok(event);
yield Ok(stream_token);
}
// Yield event for last token and compute timings
InferStreamResponse::End {
@ -600,9 +610,7 @@ async fn generate_stream_internal(
details
};
let event = on_message_callback(stream_token);
yield Ok(event);
yield Ok(stream_token);
break;
}
}
@ -610,7 +618,7 @@ async fn generate_stream_internal(
// yield error
Err(err) => {
error = true;
yield Ok(Event::from(err));
yield Err(err);
break;
}
}
@ -619,7 +627,7 @@ async fn generate_stream_internal(
// yield error
Err(err) => {
error = true;
yield Ok(Event::from(err));
yield Err(err);
}
}
// Check if generation reached the end
@ -628,7 +636,7 @@ async fn generate_stream_internal(
let err = InferError::IncompleteGenerationStream;
metrics::counter!("tgi_request_failure", "err" => "incomplete").increment(1);
tracing::error!("{err}");
yield Ok(Event::from(err));
yield Err(err);
}
}
};
@ -771,7 +779,24 @@ async fn completions(
// Create a future for each generate_stream_internal call.
let generate_future = async move {
let on_message_callback = move |stream_token: StreamResponse| {
let (header_tx, header_rx) = oneshot::channel();
let (sse_tx, sse_rx) = tokio::sync::mpsc::unbounded_channel();
tokio::spawn(async move {
let (headers, response_stream) = generate_stream_internal(
infer_clone.clone(),
compute_type_clone.clone(),
Json(generate_request),
span_clone.clone(),
)
.await;
let response_stream = async_stream::stream! {
let mut response_stream = Box::pin(response_stream);
while let Some(stream_token) = response_stream.next().await {
match stream_token {
Ok(stream_token) => {
let event = Event::default();
let current_time = std::time::SystemTime::now()
@ -817,29 +842,22 @@ async fn completions(
}),
};
event
let event = event
.json_data(message)
.unwrap_or_else(|_e| Event::default())
.unwrap_or_else(|_e| Event::default());
yield Ok(event);
}
Err(err) => yield Ok(Event::from(err)),
}
}
};
let (header_tx, header_rx) = oneshot::channel();
let (sse_tx, sse_rx) = tokio::sync::mpsc::unbounded_channel();
tokio::spawn(async move {
let (header_map, sse) = generate_stream_internal(
infer_clone.clone(),
compute_type_clone.clone(),
Json(generate_request),
on_message_callback,
span_clone.clone(),
)
.await;
// send and don't wait for the response
let _ = header_tx.send(header_map);
let _ = header_tx.send(headers);
// pin and emit messages to the sse_tx
let mut sse = Box::pin(sse);
let mut sse = Box::pin(response_stream);
while let Some(event) = sse.next().await {
if sse_tx.send(event).is_err() {
tracing::error!("Failed to send event. Receiver dropped.");
@ -1072,6 +1090,84 @@ async fn completions(
}
}
enum StreamState {
Buffering,
BufferTrailing,
Content { skip_close_quote: bool },
}
/// Convert a StreamResponse into an Event to be sent over SSE
fn create_event_from_stream_token(
stream_token: &StreamResponse,
logprobs: bool,
stream_options: Option<StreamOptions>,
inner_using_tools: bool,
system_fingerprint: String,
model_id: String,
) -> Event {
let event = Event::default();
let current_time = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_else(|_| std::time::Duration::from_secs(0))
.as_secs();
let logprobs = logprobs.then(|| {
ChatCompletionLogprobs::from((stream_token.token.clone(), stream_token.top_tokens.clone()))
});
// replace the content with the tool calls if grammar is present
let (content, tool_calls) = if inner_using_tools {
(None, Some(vec![stream_token.token.text.clone()]))
} else {
let content = if !stream_token.token.special {
Some(stream_token.token.text.clone())
} else {
None
};
(content, None)
};
let (usage, finish_reason) = match &stream_token.details {
Some(details) => {
let usage = if stream_options
.as_ref()
.map(|s| s.include_usage)
.unwrap_or(false)
{
let completion_tokens = details.generated_tokens;
let prompt_tokens = details.input_length;
let total_tokens = prompt_tokens + completion_tokens;
Some(Usage {
completion_tokens,
prompt_tokens,
total_tokens,
})
} else {
None
};
(usage, Some(details.finish_reason.format(true)))
}
None => (None, None),
};
let chat_complete = CompletionType::ChatCompletionChunk(ChatCompletionChunk::new(
model_id.clone(),
system_fingerprint.clone(),
content,
tool_calls,
current_time,
logprobs,
finish_reason,
usage,
));
event.json_data(chat_complete).unwrap_or_else(|e| {
println!("Failed to serialize ChatCompletionChunk: {:?}", e);
Event::default()
})
}
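
The usage block in this helper is only attached on the final chunk (the one carrying `details`), and only when the client opted in via `stream_options.include_usage`; total tokens are simply prompt plus completion. A compact sketch of that accounting (hypothetical helper, field names follow the diff):

def chunk_usage(details, include_usage):
    # Sketch of the usage accounting in create_event_from_stream_token above.
    if details is None or not include_usage:
        return None
    completion = details["generated_tokens"]
    prompt = details["input_length"]
    return {
        "prompt_tokens": prompt,
        "completion_tokens": completion,
        "total_tokens": prompt + completion,
    }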
/// Generate tokens
#[utoipa::path(
post,
@ -1128,88 +1224,135 @@ async fn chat_completions(
// static values that will be returned in all cases
let model_id = info.model_id.clone();
let system_fingerprint = format!("{}-{}", info.version, info.docker_label.unwrap_or("native"));
// switch on stream
if stream {
// pass this callback to the stream generation and build the required event structure
let on_message_callback = move |stream_token: StreamResponse| {
let event = Event::default();
let (headers, response_stream) =
generate_stream_internal(infer, compute_type, Json(generate_request), span).await;
// regex to match any function name
let function_regex = match Regex::new(r#"\{"function":\{"_name":"([^"]+)""#) {
Ok(regex) => regex,
Err(e) => {
return Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json(ErrorResponse {
error: format!("Failed to compile regex: {}", e),
error_type: "regex".to_string(),
}),
))
}
};
let response_stream = async_stream::stream! {
let mut response_stream = Box::pin(response_stream);
let mut buffer = Vec::new();
let mut json_buffer = String::new();
let mut state = if using_tools {
StreamState::Buffering
} else {
StreamState::Content {
skip_close_quote: false,
}
};
let mut response_as_tool = using_tools;
while let Some(result) = response_stream.next().await {
if let Ok(stream_token) = result {
let token_text = &stream_token.token.text.clone();
match state {
StreamState::Buffering => {
json_buffer.push_str(&token_text.replace(" ", ""));
buffer.push(stream_token);
if let Some(captures) = function_regex.captures(&json_buffer) {
let function_name = captures[1].to_string();
if function_name == "no_tool" {
state = StreamState::BufferTrailing;
response_as_tool = false;
buffer.clear();
json_buffer.clear();
} else {
state = StreamState::Content {
skip_close_quote: false,
};
// send all the buffered messages
for stream_token in &buffer {
let event = create_event_from_stream_token(
stream_token,
logprobs,
stream_options.clone(),
response_as_tool,
system_fingerprint.clone(),
model_id.clone(),
);
yield Ok::<Event, Infallible>(event);
}
}
}
}
// if we skipped sending the buffer we need to avoid sending the following json key and quotes
StreamState::BufferTrailing => {
let infix_text = "\"content\":\"";
json_buffer.push_str(&token_text.replace(" ", ""));
// keep capturing until we find the infix text
match json_buffer.find(infix_text) {
Some(content_key_index) => {
json_buffer =
json_buffer[content_key_index + infix_text.len()..].to_string();
}
None => {
continue;
}
}
// if there is leftover text after removing the infix text, we need to send it
if !json_buffer.is_empty() {
let event = Event::default();
let current_time = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_else(|_| std::time::Duration::from_secs(0))
.as_secs();
let logprobs = logprobs.then(|| {
ChatCompletionLogprobs::from((stream_token.token.clone(), stream_token.top_tokens))
});
// replace the content with the tool calls if grammar is present
let (content, tool_calls) = if using_tools {
(None, Some(vec![stream_token.token.text]))
} else {
let content = if !stream_token.token.special {
Some(stream_token.token.text)
} else {
None
};
(content, None)
};
let (usage, finish_reason) = match stream_token.details {
Some(details) => {
let usage = if stream_options
.as_ref()
.map(|s| s.include_usage)
.unwrap_or(false)
{
let completion_tokens = details.generated_tokens;
let prompt_tokens = details.input_length;
let total_tokens = prompt_tokens + completion_tokens;
Some(Usage {
completion_tokens,
prompt_tokens,
total_tokens,
})
} else {
None
};
(usage, Some(details.finish_reason.format(true)))
}
None => (None, None),
};
event
.json_data(CompletionType::ChatCompletionChunk(
ChatCompletionChunk::new(
let chat_complete =
CompletionType::ChatCompletionChunk(ChatCompletionChunk::new(
model_id.clone(),
system_fingerprint.clone(),
content,
tool_calls,
Some(json_buffer.clone()),
None,
current_time,
logprobs,
finish_reason,
usage,
),
))
.unwrap_or_else(|e| {
println!("Failed to serialize ChatCompletionChunk: {:?}", e);
Event::default()
})
};
let (headers, response_stream) = generate_stream_internal(
infer,
compute_type,
Json(generate_request),
on_message_callback,
span,
)
.await;
let response_stream = response_stream.chain(futures::stream::once(async {
Ok(Event::default().data("[DONE]"))
None,
None,
None,
));
yield Ok(event.json_data(chat_complete).unwrap_or_else(|e| {
InferError::StreamSerializationError(e.to_string()).into()
}));
}
// cleanup the buffers
buffer.clear();
json_buffer.clear();
state = StreamState::Content {
skip_close_quote: true,
};
}
StreamState::Content { skip_close_quote } => {
if skip_close_quote && token_text.contains('"') {
break;
}
// send the content
let event = create_event_from_stream_token(
&stream_token,
logprobs,
stream_options.clone(),
response_as_tool,
system_fingerprint.clone(),
model_id.clone(),
);
yield Ok::<Event, Infallible>(event);
}
}
}
}
yield Ok::<Event, Infallible>(Event::default().data("[DONE]"));
};
let sse = Sse::new(response_stream).keep_alive(KeepAlive::default());
Ok((headers, sse).into_response())
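
The buffering logic above is easier to follow in isolation: tokens accumulate until the grammar-constrained output reveals the function name; `no_tool` flips the stream into plain-content mode after the `"content":"` prefix is discarded, while any other name flushes the buffer as tool-call chunks. A condensed Python re-implementation of the same state machine (token strings stand in for StreamResponse objects; it models only the using_tools=true path and is a sketch, not the shipped code):

import re

FUNCTION_RE = re.compile(r'\{"function":\{"_name":"([^"]+)"')  # same pattern as above
INFIX = '"content":"'

def route_tool_stream(tokens):
    # Sketch of the Buffering / BufferTrailing / Content state machine above.
    state, buffer, json_buffer, as_tool = "buffering", [], "", True
    for tok in tokens:
        if state == "buffering":
            json_buffer += tok.replace(" ", "")
            buffer.append(tok)
            m = FUNCTION_RE.search(json_buffer)
            if m is None:
                continue
            if m.group(1) == "no_tool":
                # drop the buffered prefix and start stripping the JSON wrapper
                state, buffer, json_buffer, as_tool = "buffer_trailing", [], "", False
            else:
                state = "content"
                for buffered in buffer:
                    yield ("tool", buffered)       # flush buffer as tool-call deltas
        elif state == "buffer_trailing":
            json_buffer += tok.replace(" ", "")
            idx = json_buffer.find(INFIX)
            if idx == -1:
                continue                            # keep capturing until "content":"
            leftover = json_buffer[idx + len(INFIX):]
            if leftover:
                yield ("content", leftover)         # text that followed the infix
            state = "content_skip_close_quote"
        elif state == "content_skip_close_quote":
            if '"' in tok:
                break                               # closing quote of the JSON string
            yield ("content", tok)
        else:  # content after a real tool call was detected
            yield ("tool" if as_tool else "content", tok)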
@ -1246,7 +1389,21 @@ async fn chat_completions(
if let Value::Object(ref mut props) = arguments {
props.remove("_name");
}
match name.as_str() {
"no_tool" => {
// parse the content message
let content_message = arguments
.get("content")
.and_then(Value::as_str)
.ok_or_else(|| {
InferError::ToolError(
"No `content` found in generated text".to_string(),
)
})?
.to_string();
(None, Some(content_message))
}
_ => {
let tool_calls = vec![ToolCall {
id: "0".to_string(),
r#type: "function".to_string(),
@ -1257,6 +1414,8 @@ async fn chat_completions(
},
}];
(Some(tool_calls), None)
}
}
} else {
(None, Some(generation.generated_text))
};
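
In the non-streaming branch this means the grammar-constrained output is parsed back: a `no_tool` call is unwrapped into plain assistant content, anything else becomes a tool call. A small sketch of that unwrapping (field names follow the diff; the surrounding response assembly is omitted):

import json

def split_tool_output(generated_text: str):
    # Sketch: mirror of the no_tool unwrapping above.
    parsed = json.loads(generated_text)
    call = parsed["function"]
    name = call.pop("_name")               # props.remove("_name") in the Rust code
    if name == "no_tool":
        content = call.get("content")
        if not isinstance(content, str):
            raise ValueError("No `content` found in generated text")
        return None, content                # (tool_calls, content)
    tool_calls = [{
        "id": "0",
        "type": "function",
        "function": {"name": name, "arguments": call},
    }]
    return tool_calls, None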
@ -1937,6 +2096,11 @@ async fn start(
metrics::Unit::Count,
"Maximum tokens for the current batch"
);
metrics::describe_gauge!(
"tgi_batch_total_tokens",
metrics::Unit::Count,
"Maximum amount of tokens in total."
);
metrics::describe_histogram!(
"tgi_request_max_new_tokens",
metrics::Unit::Count,
@ -2318,6 +2482,7 @@ impl From<InferError> for (StatusCode, Json<ErrorResponse>) {
InferError::TemplateError(_) => StatusCode::UNPROCESSABLE_ENTITY,
InferError::MissingTemplateVariable(_) => StatusCode::UNPROCESSABLE_ENTITY,
InferError::ToolError(_) => StatusCode::UNPROCESSABLE_ENTITY,
InferError::StreamSerializationError(_) => StatusCode::INTERNAL_SERVER_ERROR,
};
(
@ -2495,8 +2660,8 @@ mod tests {
);
assert!(result.is_ok());
let (inputs, _grammar, using_tools) = result.unwrap();
let (inputs, _grammar, using_tools) = result.expect("Failed to prepare chat input");
assert_eq!(using_tools, true);
assert_eq!(inputs, "<s>[AVAILABLE_TOOLS] [{\"type\": \"function\", \"function\": {\"arguments\": {\"properties\":{\"format\":{\"description\":\"The temperature unit to use. Infer this from the users location.\",\"enum\":[\"celsius\",\"fahrenheit\"],\"type\":\"string\"},\"location\":{\"description\":\"The city and state, e.g. San Francisco, CA\",\"type\":\"string\"}},\"required\":[\"location\",\"format\"],\"type\":\"object\"}, \"description\": \"Get the current weather\", \"name\": \"get_current_weather\"}}, {\"type\": \"function\", \"function\": {\"arguments\": {\"properties\":{\"error\":{\"description\":\"The error or issue to notify\",\"type\":\"string\"}},\"required\":[\"error\"],\"type\":\"object\"}, \"description\": \"Notify an error or issue\", \"name\": \"notify_error\"}}][/AVAILABLE_TOOLS][INST] What is the weather like in New York?\n---\nGiven the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.[/INST]".to_string());
assert_eq!(inputs, "<s>[AVAILABLE_TOOLS] [{\"type\": \"function\", \"function\": {\"arguments\": {\"properties\":{\"format\":{\"description\":\"The temperature unit to use. Infer this from the users location.\",\"enum\":[\"celsius\",\"fahrenheit\"],\"type\":\"string\"},\"location\":{\"description\":\"The city and state, e.g. San Francisco, CA\",\"type\":\"string\"}},\"required\":[\"location\",\"format\"],\"type\":\"object\"}, \"description\": \"Get the current weather\", \"name\": \"get_current_weather\"}}, {\"type\": \"function\", \"function\": {\"arguments\": {\"properties\":{\"content\":{\"description\":\"The response content\",\"type\":\"string\"}},\"required\":[\"content\"],\"type\":\"object\"}, \"description\": \"Open ened response with no specific tool selected\", \"name\": \"no_tool\"}}][/AVAILABLE_TOOLS][INST] What is the weather like in New York?\n---\nGiven the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.[/INST]".to_string());
}
}

View File

@ -567,6 +567,7 @@ fn image_tokens(
use HubPreprocessorConfig::*;
match config {
Idefics => "<image>".to_string(),
Mllama => "<|image|>".to_string(),
Idefics2(config) => {
const FAKE: &str = "<fake_token_around_image>";
const IMAGE: &str = "<image>";
@ -618,7 +619,7 @@ fn prepare_input(
use Config::*;
static RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"!\[\]\([^\)]*\)").unwrap());
let (tokenizer_query, input_chunks) = match config {
Some(config @ (Idefics | Idefics2(_) | Paligemma(_) | LlavaNext(_))) => {
Some(config @ (Idefics | Mllama | Idefics2(_) | Paligemma(_) | LlavaNext(_))) => {
let mut input_chunks = Vec::new();
let mut tokenizer_query = String::with_capacity(inputs.len());
let mut start = 0;
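
The effect of `prepare_input` on multimodal prompts is that every markdown-style image reference is cut out and replaced by the model-specific token from `image_tokens`, `<|image|>` in the new Mllama case. A small illustration of the substitution (sketch only; the real code also collects each URL into a separate input chunk):

import re

RE = re.compile(r"!\[\]\([^\)]*\)")  # same pattern as the Lazy<Regex> above
prompt = "Describe this: ![](https://example.com/cat.png) please."
print(RE.sub("<|image|>", prompt))
# Describe this: <|image|> please.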

View File

@ -1,5 +1,5 @@
[toolchain]
# Released on: August 8, 2024
# https://releases.rs/docs/1.80.1/
channel = "1.80.0"
channel = "1.80.1"
components = ["rustfmt", "clippy"]

View File

@ -1,5 +1,5 @@
flash_att_v2_commit_cuda := v2.6.1
flash_att_v2_commit_rocm := 2554f490101742ccdc56620a938f847f61754be6
flash_att_v2_commit_rocm := 2092111b9f975b3347c652ff7fabd431130256c4
build-flash-attention-v2-cuda:
pip install -U packaging wheel
@ -11,7 +11,7 @@ install-flash-attention-v2-cuda: build-flash-attention-v2-cuda
build-flash-attention-v2-rocm:
if [ ! -d 'flash-attention-v2' ]; then \
pip install -U packaging ninja --no-cache-dir && \
git clone https://github.com/ROCm/flash-attention.git flash-attention-v2 && \
git clone https://github.com/mht-sharma/flash-attention.git flash-attention-v2 && \
cd flash-attention-v2 && git fetch && git checkout $(flash_att_v2_commit_rocm) && \
git submodule update --init --recursive && GPU_ARCHS="gfx90a;gfx942" PYTORCH_ROCM_ARCH="gfx90a;gfx942" python setup.py build; \
fi

View File

@ -1,5 +1,5 @@
commit_cuda := d243e9dc7e2c9c2e36a4150ec8e64809cb55c01b
commit_rocm := c6ee53b1be97e3bbc791b95f22827501297f8921
commit_rocm := 4e0929e6e4fa0a3d09d358715c288020ea9dc247
build-vllm-cuda:
if [ ! -d 'vllm' ]; then \
pip install -U ninja packaging --no-cache-dir && \
@ -13,7 +13,7 @@ install-vllm-cuda: build-vllm-cuda
build-vllm-rocm:
if [ ! -d 'vllm' ]; then \
pip install -U ninja packaging --no-cache-dir && \
git clone https://github.com/fxmarty/rocm-vllm.git vllm; \
git clone https://github.com/mht-sharma/vllm.git vllm; \
fi
cd vllm && git fetch && git checkout $(commit_rocm) && \
PYTORCH_ROCM_ARCH="gfx90a;gfx942" python setup.py build

View File

@ -1,5 +1,17 @@
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
import torch
extra_cuda_cflags = []
extra_cflags = []
if torch.version.hip:
extra_cflags = ["-DLEGACY_HIPBLAS_DIRECT=ON"]
extra_cuda_cflags = ["-DLEGACY_HIPBLAS_DIRECT=ON"]
extra_compile_args = {
"cxx": extra_cflags,
"nvcc": extra_cuda_cflags,
}
setup(
name="exllama_kernels",
@ -13,6 +25,7 @@ setup(
"exllama_kernels/cuda_func/q4_matmul.cu",
"exllama_kernels/cuda_func/q4_matrix.cu",
],
extra_compile_args=extra_compile_args,
)
],
cmdclass={"build_ext": BuildExtension},

View File

@ -3,11 +3,13 @@ from torch.utils.cpp_extension import BuildExtension, CUDAExtension
import torch
extra_cuda_cflags = ["-lineinfo", "-O3"]
extra_cflags = []
if torch.version.hip:
extra_cuda_cflags += ["-DHIPBLAS_USE_HIP_HALF"]
extra_cflags = ["-DLEGACY_HIPBLAS_DIRECT=ON"]
extra_cuda_cflags += ["-DHIPBLAS_USE_HIP_HALF", "-DLEGACY_HIPBLAS_DIRECT=ON"]
extra_compile_args = {
"cxx": extra_cflags,
"nvcc": extra_cuda_cflags,
}

server/poetry.lock generated (2583 lines changed)

File diff suppressed because it is too large

View File

@ -23,10 +23,10 @@ opentelemetry-api = "^1.25.0"
opentelemetry-exporter-otlp = "^1.25.0"
opentelemetry-instrumentation-grpc = "^0.46b0"
hf-transfer = "^0.1.2"
sentencepiece = "^0.1.97"
tokenizers = "^0.19.1"
sentencepiece = "^0.2"
tokenizers = "^0.20"
huggingface-hub = "^0.23"
transformers = "^4.43"
transformers = "^4.45"
einops = "^0.6.1"
texttable = { version = "^1.6.7", optional = true }
datasets = { version = "^2.14.0", optional = true }
@ -47,10 +47,10 @@ marlin-kernels = [
{ url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl", python = "~3.12", optional = true },
]
moe-kernels = [
{ url = "https://github.com/danieldk/moe-kernels/releases/download/v0.3.1/moe_kernels-0.3.1+cu123torch2.4-cp39-cp39-linux_x86_64.whl", python = "~3.9", optional = true },
{ url = "https://github.com/danieldk/moe-kernels/releases/download/v0.3.1/moe_kernels-0.3.1+cu123torch2.4-cp310-cp310-linux_x86_64.whl", python = "~3.10", optional = true },
{ url = "https://github.com/danieldk/moe-kernels/releases/download/v0.3.1/moe_kernels-0.3.1+cu123torch2.4-cp311-cp311-linux_x86_64.whl", python = "~3.11", optional = true },
{ url = "https://github.com/danieldk/moe-kernels/releases/download/v0.3.1/moe_kernels-0.3.1+cu123torch2.4-cp312-cp312-linux_x86_64.whl", python = "~3.12", optional = true },
{ url = "https://github.com/danieldk/moe-kernels/releases/download/v0.6.0/moe_kernels-0.6.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl", python = "~3.9", optional = true },
{ url = "https://github.com/danieldk/moe-kernels/releases/download/v0.6.0/moe_kernels-0.6.0+cu123torch2.4-cp310-cp310-linux_x86_64.whl", python = "~3.10", optional = true },
{ url = "https://github.com/danieldk/moe-kernels/releases/download/v0.6.0/moe_kernels-0.6.0+cu123torch2.4-cp311-cp311-linux_x86_64.whl", python = "~3.11", optional = true },
{ url = "https://github.com/danieldk/moe-kernels/releases/download/v0.6.0/moe_kernels-0.6.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl", python = "~3.12", optional = true },
]
rich = "^13.7.1"
@ -82,3 +82,6 @@ requires = [
"poetry-core>=1.0.0",
]
build-backend = "poetry.core.masonry.api"
[tool.isort]
profile = "black"

View File

@ -1,19 +1,19 @@
certifi==2024.7.4 ; python_version >= "3.9" and python_version < "3.13"
certifi==2024.8.30 ; python_version >= "3.9" and python_version < "3.13"
charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "3.13"
click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
filelock==3.15.4 ; python_version >= "3.9" and python_version < "3.13"
fsspec==2024.5.0 ; python_version >= "3.9" and python_version < "3.13"
googleapis-common-protos==1.63.2 ; python_version >= "3.9" and python_version < "3.13"
filelock==3.16.1 ; python_version >= "3.9" and python_version < "3.13"
fsspec==2024.6.1 ; python_version >= "3.9" and python_version < "3.13"
googleapis-common-protos==1.65.0 ; python_version >= "3.9" and python_version < "3.13"
grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
grpcio-reflection==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
grpcio-status==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
grpcio==1.65.1 ; python_version >= "3.9" and python_version < "3.13"
grpcio-reflection==1.62.3 ; python_version >= "3.9" and python_version < "3.13"
grpcio-status==1.62.3 ; python_version >= "3.9" and python_version < "3.13"
grpcio==1.66.1 ; python_version >= "3.9" and python_version < "3.13"
hf-transfer==0.1.8 ; python_version >= "3.9" and python_version < "3.13"
huggingface-hub==0.23.5 ; python_version >= "3.9" and python_version < "3.13"
idna==3.7 ; python_version >= "3.9" and python_version < "3.13"
idna==3.10 ; python_version >= "3.9" and python_version < "3.13"
importlib-metadata==7.1.0 ; python_version >= "3.9" and python_version < "3.13"
loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
markdown-it-py==3.0.0 ; python_version >= "3.9" and python_version < "3.13"
@ -32,23 +32,23 @@ opentelemetry-semantic-conventions==0.46b0 ; python_version >= "3.9" and python_
packaging==24.1 ; python_version >= "3.9" and python_version < "3.13"
pillow==10.4.0 ; python_version >= "3.9" and python_version < "3.13"
prometheus-client==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
protobuf==4.25.3 ; python_version >= "3.9" and python_version < "3.13"
protobuf==4.25.5 ; python_version >= "3.9" and python_version < "3.13"
py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13"
pygments==2.18.0 ; python_version >= "3.9" and python_version < "3.13"
pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
regex==2024.5.15 ; python_version >= "3.9" and python_version < "3.13"
pyyaml==6.0.2 ; python_version >= "3.9" and python_version < "3.13"
regex==2024.9.11 ; python_version >= "3.9" and python_version < "3.13"
requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13"
rich==13.7.1 ; python_version >= "3.9" and python_version < "3.13"
safetensors==0.4.3 ; python_version >= "3.9" and python_version < "3.13"
rich==13.8.1 ; python_version >= "3.9" and python_version < "3.13"
safetensors==0.4.5 ; python_version >= "3.9" and python_version < "3.13"
scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
setuptools==71.1.0 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.19.1 ; python_version >= "3.9" and python_version < "3.13"
tqdm==4.66.4 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.43.1 ; python_version >= "3.9" and python_version < "3.13"
sentencepiece==0.2.0 ; python_version >= "3.9" and python_version < "3.13"
setuptools==75.1.0 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
tqdm==4.66.5 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.45.0 ; python_version >= "3.9" and python_version < "3.13"
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
typing-extensions==4.12.2 ; python_version >= "3.9" and python_version < "3.13"
urllib3==2.2.2 ; python_version >= "3.9" and python_version < "3.13"
urllib3==2.2.3 ; python_version >= "3.9" and python_version < "3.13"
win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
wrapt==1.16.0 ; python_version >= "3.9" and python_version < "3.13"
zipp==3.19.2 ; python_version >= "3.9" and python_version < "3.13"
zipp==3.20.2 ; python_version >= "3.9" and python_version < "3.13"

View File

@ -1,19 +1,19 @@
certifi==2024.7.4 ; python_version >= "3.9" and python_version < "3.13"
certifi==2024.8.30 ; python_version >= "3.9" and python_version < "3.13"
charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "3.13"
click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
filelock==3.15.4 ; python_version >= "3.9" and python_version < "3.13"
fsspec==2024.5.0 ; python_version >= "3.9" and python_version < "3.13"
googleapis-common-protos==1.63.2 ; python_version >= "3.9" and python_version < "3.13"
filelock==3.16.1 ; python_version >= "3.9" and python_version < "3.13"
fsspec==2024.6.1 ; python_version >= "3.9" and python_version < "3.13"
googleapis-common-protos==1.65.0 ; python_version >= "3.9" and python_version < "3.13"
grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
grpcio-reflection==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
grpcio-status==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
grpcio==1.65.1 ; python_version >= "3.9" and python_version < "3.13"
grpcio-reflection==1.62.3 ; python_version >= "3.9" and python_version < "3.13"
grpcio-status==1.62.3 ; python_version >= "3.9" and python_version < "3.13"
grpcio==1.66.1 ; python_version >= "3.9" and python_version < "3.13"
hf-transfer==0.1.8 ; python_version >= "3.9" and python_version < "3.13"
huggingface-hub==0.23.5 ; python_version >= "3.9" and python_version < "3.13"
idna==3.7 ; python_version >= "3.9" and python_version < "3.13"
idna==3.10 ; python_version >= "3.9" and python_version < "3.13"
importlib-metadata==7.1.0 ; python_version >= "3.9" and python_version < "3.13"
loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
markdown-it-py==3.0.0 ; python_version >= "3.9" and python_version < "3.13"
@ -32,23 +32,23 @@ opentelemetry-semantic-conventions==0.46b0 ; python_version >= "3.9" and python_
packaging==24.1 ; python_version >= "3.9" and python_version < "3.13"
pillow==10.4.0 ; python_version >= "3.9" and python_version < "3.13"
prometheus-client==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
protobuf==4.25.3 ; python_version >= "3.9" and python_version < "3.13"
protobuf==4.25.5 ; python_version >= "3.9" and python_version < "3.13"
py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13"
pygments==2.18.0 ; python_version >= "3.9" and python_version < "3.13"
pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
regex==2024.5.15 ; python_version >= "3.9" and python_version < "3.13"
pyyaml==6.0.2 ; python_version >= "3.9" and python_version < "3.13"
regex==2024.9.11 ; python_version >= "3.9" and python_version < "3.13"
requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13"
rich==13.7.1 ; python_version >= "3.9" and python_version < "3.13"
safetensors==0.4.3 ; python_version >= "3.9" and python_version < "3.13"
rich==13.8.1 ; python_version >= "3.9" and python_version < "3.13"
safetensors==0.4.5 ; python_version >= "3.9" and python_version < "3.13"
scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
setuptools==71.1.0 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.19.1 ; python_version >= "3.9" and python_version < "3.13"
tqdm==4.66.4 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.43.1 ; python_version >= "3.9" and python_version < "3.13"
sentencepiece==0.2.0 ; python_version >= "3.9" and python_version < "3.13"
setuptools==75.1.0 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
tqdm==4.66.5 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.45.0 ; python_version >= "3.9" and python_version < "3.13"
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
typing-extensions==4.12.2 ; python_version >= "3.9" and python_version < "3.13"
urllib3==2.2.2 ; python_version >= "3.9" and python_version < "3.13"
urllib3==2.2.3 ; python_version >= "3.9" and python_version < "3.13"
win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
wrapt==1.16.0 ; python_version >= "3.9" and python_version < "3.13"
zipp==3.19.2 ; python_version >= "3.9" and python_version < "3.13"
zipp==3.20.2 ; python_version >= "3.9" and python_version < "3.13"

View File

@ -1,19 +1,19 @@
certifi==2024.7.4 ; python_version >= "3.9" and python_version < "3.13"
certifi==2024.8.30 ; python_version >= "3.9" and python_version < "3.13"
charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "3.13"
click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
filelock==3.15.4 ; python_version >= "3.9" and python_version < "3.13"
fsspec==2024.5.0 ; python_version >= "3.9" and python_version < "3.13"
googleapis-common-protos==1.63.2 ; python_version >= "3.9" and python_version < "3.13"
filelock==3.16.1 ; python_version >= "3.9" and python_version < "3.13"
fsspec==2024.6.1 ; python_version >= "3.9" and python_version < "3.13"
googleapis-common-protos==1.65.0 ; python_version >= "3.9" and python_version < "3.13"
grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
grpcio-reflection==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
grpcio-status==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
grpcio==1.65.1 ; python_version >= "3.9" and python_version < "3.13"
grpcio-reflection==1.62.3 ; python_version >= "3.9" and python_version < "3.13"
grpcio-status==1.62.3 ; python_version >= "3.9" and python_version < "3.13"
grpcio==1.66.1 ; python_version >= "3.9" and python_version < "3.13"
hf-transfer==0.1.8 ; python_version >= "3.9" and python_version < "3.13"
huggingface-hub==0.23.5 ; python_version >= "3.9" and python_version < "3.13"
idna==3.7 ; python_version >= "3.9" and python_version < "3.13"
idna==3.10 ; python_version >= "3.9" and python_version < "3.13"
importlib-metadata==7.1.0 ; python_version >= "3.9" and python_version < "3.13"
loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
markdown-it-py==3.0.0 ; python_version >= "3.9" and python_version < "3.13"
@ -32,23 +32,23 @@ opentelemetry-semantic-conventions==0.46b0 ; python_version >= "3.9" and python_
packaging==24.1 ; python_version >= "3.9" and python_version < "3.13"
pillow==10.4.0 ; python_version >= "3.9" and python_version < "3.13"
prometheus-client==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
protobuf==4.25.3 ; python_version >= "3.9" and python_version < "3.13"
protobuf==4.25.5 ; python_version >= "3.9" and python_version < "3.13"
py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13"
pygments==2.18.0 ; python_version >= "3.9" and python_version < "3.13"
pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
regex==2024.5.15 ; python_version >= "3.9" and python_version < "3.13"
pyyaml==6.0.2 ; python_version >= "3.9" and python_version < "3.13"
regex==2024.9.11 ; python_version >= "3.9" and python_version < "3.13"
requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13"
rich==13.7.1 ; python_version >= "3.9" and python_version < "3.13"
safetensors==0.4.3 ; python_version >= "3.9" and python_version < "3.13"
rich==13.8.1 ; python_version >= "3.9" and python_version < "3.13"
safetensors==0.4.5 ; python_version >= "3.9" and python_version < "3.13"
scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
setuptools==71.1.0 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.19.1 ; python_version >= "3.9" and python_version < "3.13"
tqdm==4.66.4 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.43.1 ; python_version >= "3.9" and python_version < "3.13"
sentencepiece==0.2.0 ; python_version >= "3.9" and python_version < "3.13"
setuptools==75.1.0 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
tqdm==4.66.5 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.45.0 ; python_version >= "3.9" and python_version < "3.13"
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
typing-extensions==4.12.2 ; python_version >= "3.9" and python_version < "3.13"
urllib3==2.2.2 ; python_version >= "3.9" and python_version < "3.13"
urllib3==2.2.3 ; python_version >= "3.9" and python_version < "3.13"
win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
wrapt==1.16.0 ; python_version >= "3.9" and python_version < "3.13"
zipp==3.19.2 ; python_version >= "3.9" and python_version < "3.13"
zipp==3.20.2 ; python_version >= "3.9" and python_version < "3.13"

View File

@ -30,6 +30,10 @@ class Dtype(str, Enum):
bloat16 = "bfloat16"
class KVCacheDtype(str, Enum):
fp8_e5m2 = "fp8_e5m2"
@app.command()
def serve(
model_id: str,
@ -38,6 +42,7 @@ def serve(
quantize: Optional[Quantization] = None,
speculate: Optional[int] = None,
dtype: Optional[Dtype] = None,
kv_cache_dtype: Optional[KVCacheDtype] = None,
trust_remote_code: bool = False,
uds_path: Path = "/tmp/text-generation-server",
logger_level: str = "INFO",
@ -97,6 +102,7 @@ def serve(
# Downgrade enum into str for easier management later on
quantize = None if quantize is None else quantize.value
dtype = None if dtype is None else dtype.value
kv_cache_dtype = None if kv_cache_dtype is None else kv_cache_dtype.value
if dtype is not None and quantize not in {
None,
"bitsandbytes",
@ -114,6 +120,7 @@ def serve(
quantize,
speculate,
dtype,
kv_cache_dtype,
trust_remote_code,
uds_path,
max_input_tokens,

View File

@ -1,37 +1,40 @@
from text_generation_server.utils.import_utils import SYSTEM
import os
from text_generation_server.utils.import_utils import SYSTEM
from .common import Seqlen
if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
raise ImportError("`USE_FLASH_ATTENTION` is false.")
if SYSTEM == "cuda":
from .cuda import (
PREFILL_IN_KV_CACHE,
SUPPORTS_WINDOWING,
attention,
paged_attention,
reshape_and_cache,
SUPPORTS_WINDOWING,
PREFILL_IN_KV_CACHE,
)
elif SYSTEM == "rocm":
from .rocm import (
PREFILL_IN_KV_CACHE,
SUPPORTS_WINDOWING,
attention,
paged_attention,
reshape_and_cache,
PREFILL_IN_KV_CACHE,
SUPPORTS_WINDOWING,
)
elif SYSTEM == "ipex":
from .ipex import (
PREFILL_IN_KV_CACHE,
SUPPORTS_WINDOWING,
attention,
paged_attention,
reshape_and_cache,
PREFILL_IN_KV_CACHE,
SUPPORTS_WINDOWING,
)
else:
raise ImportError(f"System {SYSTEM} doesn't support flash/paged attention")
# KVCache needs `reshape_and_cache`, so ensure that it is defined already.
from .kv_cache import KVCache
__all__ = [
"attention",
@ -39,5 +42,6 @@ __all__ = [
"reshape_and_cache",
"PREFILL_IN_KV_CACHE",
"SUPPORTS_WINDOWING",
"KVCache",
"Seqlen",
]

View File

@ -1,4 +1,5 @@
from dataclasses import dataclass
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.models.globals import ATTENTION
import torch
from typing import Optional
@ -65,5 +66,7 @@ else:
max_k: int
def clamp(self, max):
raise NotImplementedError("Not implemented seqlen for paged")
return Seqlen(torch.clamp(self.input_lengths, max=max))
if SYSTEM == "rocm":
return self
self.input_lengths = torch.clamp(self.input_lengths, max=max)
return self

View File

@ -355,3 +355,11 @@ else:
# have a configuration that requires flash-attention v1, which
# does not support block tables.
PREFILL_IN_KV_CACHE = ATTENTION != "paged" or V2
__all__ = [
"PREFILL_IN_KV_CACHE",
"SUPPORTS_WINDOWING",
"attention",
"paged_attention",
"reshape_and_cache",
]

View File

@ -50,7 +50,8 @@ def use_prefill_with_paged_kv_state(
num_kv_heads: int,
head_size: int,
page_size: int,
query_dtype: str = "float16",
dtype: torch.dtype,
window_left: int,
):
"""
Context manager to set the active flashinfer prefill state to the given
@ -90,8 +91,9 @@ def use_prefill_with_paged_kv_state(
num_qo_heads=num_heads,
num_kv_heads=num_kv_heads,
head_dim=head_size,
q_data_type=query_dtype,
q_data_type=dtype,
page_size=page_size,
window_left=window_left,
)
yield
finally:
@ -119,7 +121,8 @@ def use_prefill_state(
num_heads: int,
num_kv_heads: int,
head_size: int,
query_dtype: str = "float16",
dtype: torch.dtype,
window_left: int,
):
"""
Context manager to set the active flashinfer prefill state to the given
@ -135,7 +138,8 @@ def use_prefill_state(
num_qo_heads=num_heads,
num_kv_heads=num_kv_heads,
head_dim=head_size,
q_data_type=query_dtype,
q_data_type=dtype,
window_left=window_left,
)
yield
finally:
@ -200,7 +204,8 @@ def use_decode_state(
num_kv_heads: int,
head_size: int,
page_size: int,
query_dtype: str = "float16",
dtype: torch.dtype,
window_left: int,
):
"""
Context manager to set the active flashinfer decoding state to the given
@ -235,7 +240,9 @@ def use_decode_state(
num_kv_heads=num_kv_heads,
head_dim=head_size,
page_size=page_size,
q_data_type=query_dtype,
data_type=dtype,
q_data_type=dtype,
window_left=window_left,
)
yield
finally:

View File

@ -80,3 +80,12 @@ def paged_attention(
None,
)
return out
__all__ = [
"PREFILL_IN_KV_CACHE",
"SUPPORTS_WINDOWING",
"attention",
"paged_attention",
"reshape_and_cache",
]

View File

@ -0,0 +1,119 @@
from typing import Tuple
import torch
from text_generation_server.models.globals import ATTENTION, BLOCK_SIZE
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.layers.attention import reshape_and_cache
class KVCache:
"""
Key-value cache for attention layers.
"""
kv_cache: Tuple[torch.Tensor, torch.Tensor]
def __init__(
self,
*,
num_blocks: int,
num_heads: int,
head_size: int,
dtype: torch.dtype,
device: torch.device,
):
"""Construct the key-value cache for a layer."""
if dtype == torch.float8_e5m2 and (
ATTENTION != "flashinfer" or SYSTEM != "cuda"
):
raise ValueError(
"float8_e5m2 KV cache is currently only supported for flashinfer on CUDA"
)
element_size = torch.tensor([], dtype=dtype).element_size()
if SYSTEM == "ipex" and device.type == "xpu":
x = 1
else:
x = BLOCK_SIZE // element_size
if ATTENTION in {"flashdecoding", "flashinfer"}:
self.kv_cache = (
torch.empty(
(num_blocks, BLOCK_SIZE, num_heads, head_size),
dtype=dtype,
device=device,
),
torch.empty(
(num_blocks, BLOCK_SIZE, num_heads, head_size),
dtype=dtype,
device=device,
),
)
elif SYSTEM == "ipex" and device == torch.device("cpu"):
self.kv_cache = (
torch.empty(
(num_blocks, num_heads, BLOCK_SIZE, head_size),
dtype=dtype,
device=device,
),
torch.empty(
(num_blocks, num_heads, BLOCK_SIZE, head_size),
dtype=dtype,
device=device,
),
)
else:
self.kv_cache = (
torch.zeros(
(num_blocks, num_heads, head_size // x, BLOCK_SIZE, x),
dtype=dtype,
device=device,
),
torch.zeros(
(num_blocks, num_heads, head_size, BLOCK_SIZE),
dtype=dtype,
device=device,
),
)
@property
def key(self):
"""Get the key cache."""
return self.kv_cache[0]
@property
def value(self):
"""Get the value cache."""
return self.kv_cache[1]
def store(
self,
*,
key: torch.Tensor,
value: torch.Tensor,
slots: torch.Tensor,
):
"""Store the key and value at the given slots."""
key_cache = self.kv_cache[0]
value_cache = self.kv_cache[1]
if ATTENTION in {"flashdecoding", "flashinfer"}:
# TODO: add scale
key = key.to(key_cache.dtype)
value = value.to(value_cache.dtype)
if key_cache.dtype == torch.float8_e5m2:
# Torch index_put does not support float8_e5m2 yet, so
# put as raw data instead.
key_cache = key_cache.view(torch.uint8)
value_cache = value_cache.view(torch.uint8)
key = key.view(torch.uint8)
value = value.view(torch.uint8)
shape = key_cache.shape
key_cache.view(-1, shape[-2], shape[-1])[slots] = key
value_cache.view(-1, shape[-2], shape[-1])[slots] = value
else:
reshape_and_cache(key, value, key_cache, value_cache, slots)
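
The constructor picks one of three layouts depending on the backend, which is easiest to see with concrete numbers. A sketch computing the shapes for an illustrative config (1024 blocks, 8 KV heads, head_size 128, fp16; BLOCK_SIZE comes from models.globals in TGI and is assumed to be 16 here):

import torch

num_blocks, num_heads, head_size, block_size = 1024, 8, 128, 16
dtype = torch.float16

x = block_size // torch.tensor([], dtype=dtype).element_size()  # 16 // 2 = 8

# flashdecoding / flashinfer: identical key and value layouts
flash_shape = (num_blocks, block_size, num_heads, head_size)

# ipex on CPU: heads come before the block dimension
ipex_shape = (num_blocks, num_heads, block_size, head_size)

# default paged layout: key is tiled by x, value is transposed
paged_key_shape = (num_blocks, num_heads, head_size // x, block_size, x)
paged_value_shape = (num_blocks, num_heads, head_size, block_size)

print(flash_shape, ipex_shape, paged_key_shape, paged_value_shape)
# (1024, 16, 8, 128) (1024, 8, 16, 128) (1024, 8, 16, 16, 8) (1024, 8, 128, 16)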

View File

@ -1,4 +1,5 @@
import os
from typing import Optional
import torch
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.models.globals import ATTENTION
@ -8,16 +9,28 @@ from loguru import logger
major, minor = torch.cuda.get_device_capability()
is_sm75 = major == 7 and minor == 5
_PARTITION_SIZE = 512
_PARTITION_SIZE_V1V2 = 512
_PARTITION_SIZE_CUSTOM = 256
use_triton = os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "").lower() in {"true", "1"}
ENGINE = "triton" if use_triton else "ck"
PREFILL_IN_KV_CACHE = False
use_rocm_custom_paged_attn = os.getenv("ROCM_USE_CUSTOM_PAGED_ATTN", "1") != "0"
try:
from vllm._C import cache_ops
if use_rocm_custom_paged_attn:
from vllm._custom_C import paged_attention_custom
except ImportError as e:
log_master(
logger.info,
f"Custom Paged Attention not available. Complete error: {e}",
)
use_rocm_custom_paged_attn = False
try:
import vllm._custom_ops as ops
except Exception as e:
raise ImportError(
f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}"
@ -36,9 +49,7 @@ def reshape_and_cache(
key_cache.view(-1, shape[-2], shape[-1])[slots] = key
value_cache.view(-1, shape[-2], shape[-1])[slots] = value
else:
cache_ops.reshape_and_cache(
key, value, key_cache, value_cache, slots, "auto", 1.0
)
ops.reshape_and_cache(key, value, key_cache, value_cache, slots, "auto", 1.0)
def paged_attention(
@ -48,8 +59,9 @@ def paged_attention(
kv_head_mapping: torch.Tensor,
softmax_scale: float,
block_tables: torch.Tensor,
input_lengths: Seqlen,
seqlen: Seqlen,
max_s: int,
softcap: Optional[float] = None,
):
# Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py
# Copyright 2023 The vLLM team. All rights
@ -68,11 +80,31 @@ def paged_attention(
# limitations under the License.
#
if softcap is not None:
raise RuntimeError("Paged attention doesn't support softcapping")
# value_cache => [num_blocks, num_heads, head_size, block_size]
block_size = value_cache.shape[3]
num_seqs, num_heads, head_size = query.shape
num_kv_heads = key_cache.shape[1]
gqa_ratio = num_heads // num_kv_heads
use_custom = (
use_rocm_custom_paged_attn
and (query.dtype == torch.half or query.dtype == torch.bfloat16)
and (head_size == 128 or head_size == 64)
and (block_size == 16 or block_size == 32)
and (gqa_ratio >= 1 and gqa_ratio <= 16)
and max_s <= 32768
)
if not use_custom:
_PARTITION_SIZE = _PARTITION_SIZE_V1V2
else:
_PARTITION_SIZE = _PARTITION_SIZE_CUSTOM
max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
input_lengths = input_lengths.input_lengths
input_lengths = seqlen.input_lengths
out = torch.empty_like(query)
@ -81,9 +113,13 @@ def paged_attention(
# V1 to avoid the overhead of reduction. Also, if the number of
# sequences or heads is large, we use V1 since there is enough work
# to parallelize.
from vllm._C import ops
import vllm._custom_ops as ops
use_v1 = max_s <= 8192 and (max_num_partitions == 1 or num_seqs * num_heads > 512)
use_v1 = (
max_s <= 8192
and (max_num_partitions == 1 or num_seqs * num_heads > 512)
and not use_custom
)
if use_v1:
ops.paged_attention_v1(
out,
@ -115,6 +151,7 @@ def paged_attention(
)
max_logits = torch.empty_like(exp_sums)
if not use_custom:
ops.paged_attention_v2(
out,
exp_sums,
@ -133,6 +170,25 @@ def paged_attention(
"auto",
1.0,
)
else:
paged_attention_custom(
out,
exp_sums,
max_logits,
tmp_output,
query,
key_cache,
value_cache,
num_kv_heads,
softmax_scale,
block_tables,
input_lengths,
block_size,
max_s,
None,
"auto",
)
return out
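
The dispatch above chooses between vLLM's V1/V2 kernels and the ROCm custom kernel; the decision and the partition arithmetic can be summarized compactly. A pure-Python sketch of just the selection logic, mirroring the conditions in the diff:

def select_paged_attention_kernel(
    dtype, head_size, block_size, num_heads, num_kv_heads, num_seqs, max_s,
    rocm_custom_enabled=True,
):
    # Sketch: mirrors the use_custom / use_v1 conditions above.
    gqa_ratio = num_heads // num_kv_heads
    use_custom = (
        rocm_custom_enabled
        and dtype in ("half", "bfloat16")
        and head_size in (64, 128)
        and block_size in (16, 32)
        and 1 <= gqa_ratio <= 16
        and max_s <= 32768
    )
    partition_size = 256 if use_custom else 512
    max_num_partitions = (max_s + partition_size - 1) // partition_size
    use_v1 = (
        max_s <= 8192
        and (max_num_partitions == 1 or num_seqs * num_heads > 512)
        and not use_custom
    )
    if use_custom:
        return "paged_attention_custom", max_num_partitions
    return ("paged_attention_v1" if use_v1 else "paged_attention_v2"), max_num_partitions

# e.g. a 4k-context decode with 32 query heads over 8 KV heads on fp16:
print(select_paged_attention_kernel("half", 128, 16, 32, 8, 1, 4096))
# -> ('paged_attention_custom', 16)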
@ -175,13 +231,14 @@ if ENGINE == "ck":
def attention(
q,
k,
v,
cu_seqlens,
max_s,
softmax_scale,
window_size_left=-1,
causal=True,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
seqlen: Seqlen,
block_tables: torch.Tensor,
softmax_scale: float,
window_size_left: int = -1,
causal: bool = True,
softcap: float = 0.0,
):
if window_size_left <= 0 and window_size_left != -1:
raise ValueError("`window_size_left` must be > 0 or -1")
@ -191,46 +248,57 @@ if ENGINE == "ck":
# We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
return flash_attn_2_cuda.varlen_fwd(
q,
k,
v,
key_cache,
value_cache,
out,
cu_seqlens,
cu_seqlens,
max_s,
max_s,
seqlen.cu_seqlen_q,
seqlen.cu_seqlen_q,
None,
None,
None,
None,
seqlen.max_q,
seqlen.max_k,
0.0,
softmax_scale,
False,
causal,
window_size_left,
0,
softcap,
False,
None,
)
)[0]
elif ENGINE == "triton":
from .flash_attn_triton import triton_attention
def attention(
q,
k,
v,
cu_seqlens,
max_s,
softmax_scale,
window_size_left=-1,
causal=True,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
seqlen: Seqlen,
block_tables: torch.Tensor,
softmax_scale: float,
window_size_left: int = -1,
causal: bool = True,
softcap: Optional[float] = None,
):
if softcap is not None:
raise NotImplementedError("softcap is only available with CK flash attn")
out = torch.empty_like(q)
# We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
output, _ = triton_attention(
q,
k,
v,
key_cache,
value_cache,
out,
cu_seqlens,
cu_seqlens,
max_s,
max_s,
seqlen.cu_seqlen_q,
seqlen.cu_seqlen_q,
seqlen.max_q,
seqlen.max_k,
causal,
softmax_scale,
)
@ -238,3 +306,11 @@ elif ENGINE == "triton":
else:
raise RuntimeError(f"Unknown attention engine {ENGINE}")
__all__ = [
"PREFILL_IN_KV_CACHE",
"SUPPORTS_WINDOWING",
"attention",
"paged_attention",
"reshape_and_cache",
]

View File

@ -1,12 +1,21 @@
import torch
from text_generation_server.utils.import_utils import SYSTEM
from torch.nn import functional as F
import os
if SYSTEM == "rocm":
ROCM_USE_SKINNY_GEMM = os.getenv("ROCM_USE_SKINNY_GEMM", "True").lower() in (
"true",
"1",
)
if ROCM_USE_SKINNY_GEMM:
try:
from vllm import _custom_C
except Exception as e:
raise ImportError(f"Could not load `vllm._custom_C`. Full error: {e}")
raise ImportError(
f"Could not load `vllm._custom_C` for ROCm skinny gemm. Full error: {e}"
)
class FastLinear(torch.nn.Module):
@ -48,6 +57,14 @@ class FastLinearROCm(torch.nn.Module):
else:
self.bias = None
self.cu_count = torch.cuda.get_device_properties(
device="cuda"
).multi_processor_count
self.use_skinny_gemm = (
ROCM_USE_SKINNY_GEMM
and "gfx1" not in torch.cuda.get_device_properties("cuda").gcnArchName
)
@classmethod
def load(cls, config, prefix: str, weights, bias: bool):
weight = weights.get_tensor(f"{prefix}.weight")
@ -61,7 +78,11 @@ class FastLinearROCm(torch.nn.Module):
weight = self.weight
bias = self.bias
if SYSTEM == "rocm" and inp.numel() // inp.shape[-1] == 1:
if (
self.use_skinny_gemm
and inp.dtype == torch.float16
and inp.shape[-1] % 8 == 0
):
batched = False
inp_shape = inp.shape
@ -69,13 +90,16 @@ class FastLinearROCm(torch.nn.Module):
inp = inp.view(-1, inp_shape[-1])
batched = True
m, k = weight.shape[0], inp_shape[1]
m, n, k = weight.shape[0], inp_shape[0], inp_shape[1]
if m > 8 and n <= 4:
out = torch.empty(
inp_shape[0], weight.shape[0], dtype=inp.dtype, device="cuda"
inp_shape[0], weight.shape[0], dtype=inp.dtype, device=weight.device
)
_custom_C.wvSpltK(weight, inp, out, n, self.cu_count)
elif m % 4 == 0 and n == 1 and k <= 8192:
out = torch.empty(
inp_shape[0], weight.shape[0], dtype=inp.dtype, device=weight.device
)
if (k == 8192 and (m == 1280 or m == 7168)) or (k == 3584 and m == 8192):
_custom_C.LLMM1(weight, inp, out, 8)
elif k <= 8192 and k % 8 == 0 and m % 4 == 0:
_custom_C.LLMM1(weight, inp, out, 4)
else:
out = F.linear(inp, weight)
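
The forward path above routes small-batch fp16 GEMMs to two ROCm "skinny" kernels depending on shape (`wvSpltK` and `LLMM1`, imported from `vllm._custom_C` above). A condensed view of the routing, as a sketch that returns the kernel name rather than running it:

def pick_rocm_gemm(m, n, k, dtype="float16", use_skinny=True):
    # m: output features, n: batch rows, k: input features (as in the diff).
    if not (use_skinny and dtype == "float16" and k % 8 == 0):
        return "F.linear"
    if m > 8 and n <= 4:
        return "wvSpltK"           # skinny GEMM for tiny batches
    if m % 4 == 0 and n == 1 and k <= 8192:
        return "LLMM1"             # single-row GEMV-style kernel
    return "F.linear"

print(pick_rocm_gemm(m=4096, n=2, k=4096))  # wvSpltK
print(pick_rocm_gemm(m=8, n=1, k=4096))     # LLMM1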

View File

@ -43,7 +43,7 @@ def can_use_gptq_marlin(
and quant_method in {"awq", "gptq"}
and bits in GPTQ_MARLIN_BITS
and groupsize in GPTQ_MARLIN_GROUP_SIZES
# We only suppord asymmetric quantization for AWQ.
# We only support asymmetric quantization for AWQ.
and (sym or quant_method == "awq")
)
@ -109,7 +109,6 @@ class GPTQMarlinWeightsLoader(WeightsLoader):
prefix: str,
block_sizes: Union[int, List[int]],
):
try:
qweight = weights.get_packed_sharded(
f"{prefix}.qweight", dim=1, block_sizes=block_sizes
@ -352,7 +351,7 @@ def repack_gptq_for_marlin(
scales = permute_scales(scales)
is_full_k = not (desc_act and sharded_infeatures)
is_full_k = not (desc_act and groupsize != -1 and sharded_infeatures)
return GPTQMarlinWeight(
qweight=repacked,

View File

@ -10,16 +10,24 @@ from text_generation_server.layers import (
TensorParallelRowLinear,
)
from text_generation_server.layers.fp8 import HybridFP8UnquantLoader
from text_generation_server.layers.marlin import GPTQMarlinWeightsLoader
from text_generation_server.layers.moe.gptq_marlin import (
GPTQMarlinSparseMoELayer,
can_use_marlin_moe_gemm,
)
from text_generation_server.layers.moe.unquantized import UnquantizedSparseMoELayer
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.log import log_once
from text_generation_server.utils.weights import (
DefaultWeightsLoader,
UnquantizedWeight,
Weights,
UnquantizedWeight,
)
if SYSTEM != "ipex":
if SYSTEM == "rocm":
from .fused_moe_rocm import grouped_topk
from vllm.model_executor.layers.fused_moe import fused_topk
elif SYSTEM != "ipex":
from moe_kernels.fused_moe import fused_topk, grouped_topk
@ -202,12 +210,22 @@ class SparseMoELayer(nn.Module):
and isinstance(weights.loader.weight_class, UnquantizedWeight)
) or isinstance(weights.loader, HybridFP8UnquantLoader):
cls = UnquantizedSparseMoELayer
# Once we wire up GPTQ-Marlin MoE:
# elif isinstance(weights.loader, GPTQMarlinWeightsLoader) and weights.loader.sym:
# cls = GPTQMarlinSparseMoELayer
elif isinstance(
weights.loader, GPTQMarlinWeightsLoader
) and can_use_marlin_moe_gemm(
quant_method=weights.loader.quant_method,
quantize=weights.loader.quantize,
sym=weights.loader.sym,
):
cls = GPTQMarlinSparseMoELayer
else:
raise ValueError(
f"Unsupported weights loader: {weights.loader}, sparse MoE is only supported for unquantized and GPTQ weights"
f"Unsupported weights loader: {type(weights.loader)}, sparse MoE is only supported for unquantized, AWQ, and GPTQ weights"
)
log_once(
logger.info,
"Using MoE layer wih fused gemm",
)
self.moe = cls(
@ -234,6 +252,12 @@ class SparseMoELayer(nn.Module):
and isinstance(weights.loader.weight_class, UnquantizedWeight)
)
or isinstance(weights.loader, HybridFP8UnquantLoader)
# Once we wire up GPTQ-Marlin MoE:
# or isinstance(weights.loader, GPTQMarlinWeightsLoader)
or (
isinstance(weights.loader, GPTQMarlinWeightsLoader)
and can_use_marlin_moe_gemm(
quant_method=weights.loader.quant_method,
quantize=weights.loader.quantize,
sym=weights.loader.sym,
)
)
)

View File

@ -0,0 +1,52 @@
# coding=utf-8
# Copyright 2023, 2024 DeepSeek-AI and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple
import torch
import torch.distributed
# TODO: Remove these functions once moe_kernels are built for ROCm
def grouped_topk(
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
num_expert_group: int = 0,
topk_group: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor]:
scores = torch.softmax(gating_output, dim=-1)
num_token = scores.shape[0]
group_scores = (
scores.view(num_token, num_expert_group, -1).max(dim=-1).values
) # [n, n_group]
group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[
1
] # [n, top_k_group]
group_mask = torch.zeros_like(group_scores) # [n, n_group]
group_mask.scatter_(1, group_idx, 1) # [n, n_group]
score_mask = (
group_mask.unsqueeze(-1)
.expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group)
.reshape(num_token, -1)
) # [n, e]
tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e]
topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
return topk_weights, topk_ids
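
As a quick sanity check of the fallback, the function can be exercised with random tensors; shapes follow the inline comments above (sketch with illustrative sizes, assuming grouped_topk is in scope):

import torch

# 4 tokens, 8 experts arranged in 2 groups; pick the best group,
# then the top-2 experts within it.
hidden_states = torch.randn(4, 16)
gating_output = torch.randn(4, 8)
weights, ids = grouped_topk(
    hidden_states, gating_output, topk=2, renormalize=True,
    num_expert_group=2, topk_group=1,
)
print(weights.shape, ids.shape)   # torch.Size([4, 2]) torch.Size([4, 2])
print(weights.sum(dim=-1))        # ~1.0 per token after renormalization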

View File

@ -0,0 +1,228 @@
from dataclasses import dataclass
from typing import List, Optional
import torch
import torch.nn as nn
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.weights import Weights
from text_generation_server.layers.marlin.gptq import (
GPTQMarlinWeight,
GPTQMarlinWeightsLoader,
)
if SYSTEM == "cuda":
from moe_kernels.fused_marlin_moe import fused_marlin_moe
else:
fused_marlin_moe = None
try:
major, _minor = torch.cuda.get_device_capability()
has_sm_8_0 = major >= 8
except Exception:
has_sm_8_0 = False
def can_use_marlin_moe_gemm(
*,
quant_method: str,
quantize: str,
sym: bool,
):
return (
SYSTEM == "cuda"
and fused_marlin_moe is not None
and has_sm_8_0
and quantize in {"awq", "gptq"}
and quant_method in {"awq", "gptq"}
# We only support asymmetric quantization for AWQ.
and (sym or quant_method == "awq")
)
@dataclass
class GPTQMarlinMoEWeight:
qweight: torch.Tensor
qzeros: torch.Tensor
scales: torch.Tensor
g_idx: torch.Tensor
perm: torch.Tensor
is_full_k: bool
class GPTQMarlinSparseMoELayer(nn.Module):
"""
MoE layer that uses a fused GPTQ-Marlin kernel.
"""
def __init__(
self,
*,
n_expert_group: Optional[int],
n_experts: int,
prefix: str,
renormalize: bool,
topk: int,
topk_group: Optional[int],
weights: Weights,
gate_proj_name: str = "gate_proj",
up_proj_name: str = "up_proj",
down_proj_name: str = "down_proj",
):
super().__init__()
if not (
isinstance(weights.loader, GPTQMarlinWeightsLoader)
and can_use_marlin_moe_gemm(
quant_method=weights.loader.quant_method,
quantize=weights.loader.quantize,
sym=weights.loader.sym,
)
):
raise ValueError(
f"Unsupported weights loader: {type(weights.loader)}, only GPTQMarlinWeightsLoader with AWQ and symmetric GPTQ quantization is supported"
)
assert (n_expert_group is None) == (
topk_group is None
), "n_expert_group and topk_group must both be None or have some value"
self.n_expert_group = n_expert_group
self.topk = topk
self.topk_group = topk_group
self.renormalize = renormalize
self.gate_up_proj = _load_expert_multi_weights_col(
prefix=prefix,
n_experts=n_experts,
names=[gate_proj_name, up_proj_name],
weights=weights,
)
self.down_proj = _load_expert_weights_row(
prefix=prefix, n_experts=n_experts, name=down_proj_name, weights=weights
)
self.bits = weights.loader.bits
def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
return fused_marlin_moe(
hidden_states=x,
w1=self.gate_up_proj.qweight,
w2=self.down_proj.qweight,
w1_scale=self.gate_up_proj.scales,
w2_scale=self.down_proj.scales,
w1_zeros=(
self.gate_up_proj.qzeros
if self.gate_up_proj.qzeros.numel() > 0
else None
),
w2_zeros=(
self.down_proj.qzeros if self.down_proj.qzeros.numel() > 0 else None
),
g_idx1=self.gate_up_proj.g_idx,
g_idx2=self.down_proj.g_idx,
sort_indices1=self.gate_up_proj.perm,
sort_indices2=self.down_proj.perm,
is_k_full=self.gate_up_proj.is_full_k or self.down_proj.is_full_k,
gating_output=gating_output,
topk=self.topk,
renormalize=self.renormalize,
use_grouped_topk=self.n_expert_group is not None,
num_expert_group=self.n_expert_group,
topk_group=self.topk_group,
num_bits=self.bits,
)
def _load_expert_multi_weights_col(
*,
prefix: str,
n_experts: int,
names: List[str],
weights: Weights,
) -> GPTQMarlinMoEWeight:
moe_weight = None
for i in range(n_experts):
weight = weights.get_multi_weights_col(
[f"{prefix}.{i}.{name}" for name in names], 0
)
assert isinstance(weight, GPTQMarlinWeight)
moe_weight = _pack_weight(
n_experts=n_experts, expert=i, weight=weight, moe_weight=moe_weight
)
assert moe_weight is not None
return moe_weight
def _load_expert_weights_row(
*,
prefix: str,
n_experts: int,
name: str,
weights: Weights,
) -> GPTQMarlinMoEWeight:
moe_weight = None
for i in range(n_experts):
weight = weights.get_weights_row(
f"{prefix}.{i}.{name}",
)
assert isinstance(weight, GPTQMarlinWeight)
moe_weight = _pack_weight(
n_experts=n_experts, expert=i, weight=weight, moe_weight=moe_weight
)
assert moe_weight is not None
return moe_weight
def _pack_weight(
*,
n_experts: int,
expert: int,
moe_weight: Optional[GPTQMarlinMoEWeight],
weight: GPTQMarlinWeight,
) -> GPTQMarlinMoEWeight:
if moe_weight is None:
qweight = torch.empty(
(n_experts,) + weight.qweight.shape,
dtype=weight.qweight.dtype,
device=weight.qweight.device,
)
qzeros = torch.empty(
(n_experts,) + weight.qzeros.shape,
dtype=weight.qzeros.dtype,
device=weight.qzeros.device,
)
scales = torch.empty(
(n_experts,) + weight.scales.shape,
dtype=weight.scales.dtype,
device=weight.scales.device,
)
g_idx = torch.empty(
(n_experts,) + weight.g_idx.shape,
dtype=weight.g_idx.dtype,
device=weight.g_idx.device,
)
perm = torch.empty(
(n_experts,) + weight.perm.shape,
dtype=weight.perm.dtype,
device=weight.perm.device,
)
moe_weight = GPTQMarlinMoEWeight(
qweight=qweight,
qzeros=qzeros,
scales=scales,
g_idx=g_idx,
perm=perm,
is_full_k=weight.is_full_k,
)
moe_weight.qweight[expert] = weight.qweight
moe_weight.qzeros[expert] = weight.qzeros
moe_weight.scales[expert] = weight.scales
moe_weight.g_idx[expert] = weight.g_idx
moe_weight.perm[expert] = weight.perm
return moe_weight
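A hedged sketch of wiring this layer into a model's MoE block. The prefix, expert count, and surrounding tensors are assumptions for illustration; the constructor raises ValueError unless weights.loader is a GPTQMarlinWeightsLoader that passes can_use_marlin_moe_gemm:

moe = GPTQMarlinSparseMoELayer(
    n_expert_group=None,       # plain top-k routing; must pair with topk_group=None
    n_experts=8,               # assumed expert count
    prefix="model.layers.0.moe.experts",  # hypothetical weight prefix
    renormalize=True,
    topk=2,
    topk_group=None,
    weights=weights,           # a Weights instance from the surrounding model code
)
out = moe(hidden_states, gating_output=router_logits)  # both from the caller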

View File

@ -6,7 +6,9 @@ import torch.nn as nn
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.weights import UnquantizedWeight, Weights
if SYSTEM != "ipex":
if SYSTEM == "rocm":
from vllm.model_executor.layers.fused_moe import fused_moe
elif SYSTEM != "ipex":
from moe_kernels.fused_moe import fused_moe
@ -52,6 +54,17 @@ class UnquantizedSparseMoELayer(nn.Module):
)
def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
if SYSTEM == "rocm":
return fused_moe(
x,
self.gate_up_proj,
self.down_proj,
gating_output,
self.topk,
renormalize=self.renormalize,
inplace=True,
)
return fused_moe(
x,
w1=self.gate_up_proj,
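The ROCm branch calls vLLM's fused_moe with positional arguments and inplace=True, while the non-ipex path keeps the keyword-argument moe_kernels call. A condensed restatement of the dispatch; the CUDA-path call is truncated in the hunk above, so its trailing keyword arguments are assumed:

def moe_forward(x, w1, w2, gating_output, topk, renormalize):
    if SYSTEM == "rocm":
        # vLLM's ROCm kernel: positional args, result written in place.
        return fused_moe(x, w1, w2, gating_output, topk,
                         renormalize=renormalize, inplace=True)
    return fused_moe(x, w1=w1, w2=w2, gating_output=gating_output,
                     topk=topk, renormalize=renormalize)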

View File

@ -166,6 +166,20 @@ class PositionRotaryEmbedding(nn.Module):
1 + math.log(scale) / math.log(original_max_position_embeddings)
)
# if short_mscale and long_mscale are provided we need to scale the freqs
# using the Phi3LongRoPEScaledRotaryEmbedding
if ("short_mscale" in rope_scaling) and ("long_mscale" in rope_scaling):
short_mscale = rope_scaling["short_mscale"]
long_mscale = rope_scaling["long_mscale"]
return Phi3LongRoPEScaledRotaryEmbedding(
short_inv_freq=short_inv_freq,
long_inv_freq=long_inv_freq,
max_position_embeddings=config.max_position_embeddings,
short_mscale=short_mscale,
long_mscale=long_mscale,
original_max_position_embeddings=original_max_position_embeddings,
)
return SuRotaryEmbedding(
short_inv_freq=short_inv_freq,
long_inv_freq=long_inv_freq,
@ -287,6 +301,7 @@ class SuRotaryEmbedding(PositionRotaryEmbedding):
# or if we're on a new device (possibly due to tracing for instance)
if (
seqlen > self._seq_len_cached
or self._cos_cached is None
or self._cos_cached.device != device
or self._cos_cached.dtype != dtype
):
@ -308,6 +323,63 @@ class SuRotaryEmbedding(PositionRotaryEmbedding):
self._sin_cached = (torch.sin(freqs) * self.scaling_factor).to(dtype)
class Phi3LongRoPEScaledRotaryEmbedding(PositionRotaryEmbedding):
def __init__(
self,
short_inv_freq: torch.Tensor,
long_inv_freq: torch.Tensor,
max_position_embeddings: int,
short_mscale: float,
long_mscale: float,
original_max_position_embeddings: int,
):
super(PositionRotaryEmbedding, self).__init__()
self.short_inv_freq = short_inv_freq
self.long_inv_freq = long_inv_freq
self.max_position_embeddings = max_position_embeddings
self.short_mscale = short_mscale
self.long_mscale = long_mscale
self.original_max_position_embeddings = original_max_position_embeddings
# cache
self._seq_len_cached = 0
self._cos_cached = None
self._sin_cached = None
self._cos_k_cached = None
self._sin_k_cached = None
self.dynamic_args = None
def _update_cos_sin_cache(self, dtype, device, seqlen):
if (
seqlen > self._seq_len_cached
or self._cos_cached is None
or self._cos_cached.device != device
or self._cos_cached.dtype != dtype
):
self._seq_len_cached = seqlen
t = torch.arange(seqlen, device=device, dtype=self.short_inv_freq.dtype)
short_freqs = torch.outer(
t[: self.original_max_position_embeddings],
self.short_inv_freq.to(device=t.device),
)
long_freqs = torch.outer(
t[self.original_max_position_embeddings :],
self.long_inv_freq.to(device=t.device),
)
short_freqs = short_freqs * self.short_mscale
long_freqs = long_freqs * self.long_mscale
freqs = torch.empty((seqlen, short_freqs.shape[1]), device=device)
freqs[: self.original_max_position_embeddings] = short_freqs
freqs[self.original_max_position_embeddings :] = long_freqs
self._cos_cached = torch.cos(freqs).to(dtype)
self._sin_cached = torch.sin(freqs).to(dtype)
class DynamicPositionRotaryEmbedding(PositionRotaryEmbedding):
def __init__(self, dim, max_position_embeddings, base, device, scaling_factor):
inv_freq = _create_inv_freq(dim, base, device)
@ -467,7 +539,6 @@ def apply_llama3_scaling(
elif wavelen > low_freq_wavelen:
new_freqs.append(freq / scaling_factor)
else:
assert low_freq_wavelen != high_freq_wavelen
smooth = (original_max_position_embeddings / wavelen - low_freq_factor) / (
high_freq_factor - low_freq_factor
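For the Phi3LongRoPEScaledRotaryEmbedding cache above: positions below original_max_position_embeddings use the short frequencies scaled by short_mscale, and positions beyond it use the long frequencies scaled by long_mscale. A runnable sketch with made-up sizes and scaling values (all assumed, not taken from any config):

import torch

dim, original_max, seqlen = 4, 8, 12           # illustrative sizes
short_inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
long_inv_freq = short_inv_freq / 4.0           # assumed long-context rescaling
short_mscale, long_mscale = 1.0, 1.2           # assumed config values

t = torch.arange(seqlen, dtype=short_inv_freq.dtype)
freqs = torch.empty(seqlen, dim // 2)
freqs[:original_max] = torch.outer(t[:original_max], short_inv_freq) * short_mscale
freqs[original_max:] = torch.outer(t[original_max:], long_inv_freq) * long_mscale
cos, sin = torch.cos(freqs), torch.sin(freqs)  # cached as _cos_cached/_sin_cached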

View File

@ -32,6 +32,9 @@ from text_generation_server.models.custom_modeling.phi_modeling import (
PhiConfig,
PhiForCausalLM,
)
from text_generation_server.models.custom_modeling.flash_phi_moe_modeling import (
PhiMoEConfig,
)
from text_generation_server.models.custom_modeling.t5_modeling import (
T5ForConditionalGeneration,
)
@ -73,6 +76,7 @@ FLASH_ATTENTION = True
try:
from text_generation_server.models.flash_causal_lm import FlashCausalLM
from text_generation_server.models.vlm_causal_lm import VlmCausalLM
from text_generation_server.models.mllama_causal_lm import MllamaCausalLM
from text_generation_server.models.custom_modeling.flash_deepseek_v2_modeling import (
FlashDeepseekV2ForCausalLM,
DeepseekV2Config,
@ -109,7 +113,11 @@ try:
from text_generation_server.models.custom_modeling.flash_phi_modeling import (
FlashPhiForCausalLM,
)
from text_generation_server.models.idefics import IDEFICSSharded
from text_generation_server.models.idefics_causal_lm import IdeficsCausalLM
from text_generation_server.models.mllama_causal_lm import MllamaCausalLMBatch
from text_generation_server.models.custom_modeling.mllama import (
MllamaForConditionalGeneration,
)
from text_generation_server.models.custom_modeling.llava_next import (
LlavaNextForConditionalGeneration,
)
@ -146,7 +154,7 @@ except ImportError as e:
if FLASH_ATTENTION:
__all__.append(FlashCausalLM)
__all__.append(IDEFICSSharded)
__all__.append(IdeficsCausalLM)
MAMBA_AVAILABLE = True
try:
@ -237,6 +245,11 @@ class ModelType(enum.Enum):
"name": "Phi",
"url": "https://huggingface.co/microsoft/phi-1_5",
}
PHI_MOE = {
"type": "phimoe",
"name": "PhiMoe",
"url": "https://huggingface.co/microsoft/Phi-3.5-MoE-instruct",
}
BAICHUAN = {
"type": "baichuan",
"name": "Baichuan",
@ -308,6 +321,12 @@ class ModelType(enum.Enum):
"url": "https://huggingface.co/HuggingFaceM4/idefics-9b",
"multimodal": True,
}
MLLAMA = {
"type": "mllama",
"name": "Mllama",
"url": "https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct",
"multimodal": True,
}
__GLOBALS = locals()
@ -323,6 +342,7 @@ def get_model(
quantize: Optional[str],
speculate: Optional[int],
dtype: Optional[str],
kv_cache_dtype: Optional[str],
trust_remote_code: bool,
max_input_tokens: int,
) -> Model:
@ -384,6 +404,13 @@ def get_model(
else:
raise RuntimeError(f"Unknown dtype {dtype}")
if kv_cache_dtype is None:
kv_cache_dtype = dtype
elif kv_cache_dtype == "fp8_e5m2":
kv_cache_dtype = torch.float8_e5m2
else:
raise RuntimeError(f"Unknown kv_cache_dtype: {kv_cache_dtype}")
if speculate is not None:
set_speculate(speculate)
else:
@ -544,6 +571,7 @@ def get_model(
speculator=speculator,
default_dtype=torch.bfloat16,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
config_class=DeepseekV2Config,
@ -598,6 +626,7 @@ def get_model(
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
aliases={"transformer.wte.weight": ["lm_head.weight"]},
@ -649,6 +678,7 @@ def get_model(
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
@ -684,6 +714,7 @@ def get_model(
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
@ -722,6 +753,7 @@ def get_model(
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
config_class=GPTNeoXConfig,
@ -755,6 +787,31 @@ def get_model(
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
else:
return CausalLM.fallback(
model_id,
revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
elif model_type == PHI_MOE:
if FLASH_ATTENTION:
return FlashCausalLM(
model_id=model_id,
model_class=FlashLlamaForCausalLM,
config_class=PhiMoEConfig,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
@ -794,6 +851,7 @@ def get_model(
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
@ -817,6 +875,7 @@ def get_model(
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
# Works better for these models
default_dtype=torch.bfloat16,
trust_remote_code=trust_remote_code,
@ -842,6 +901,7 @@ def get_model(
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
# Works better for these models
default_dtype=torch.bfloat16,
trust_remote_code=trust_remote_code,
@ -868,6 +928,7 @@ def get_model(
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
@ -892,6 +953,7 @@ def get_model(
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
# Dbrx works better in bfloat16.
default_dtype=torch.bfloat16,
trust_remote_code=trust_remote_code,
@ -922,6 +984,7 @@ def get_model(
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
aliases={
"lm_head.weight": ["transformer.word_embeddings.weight"],
"transformer.word_embeddings.weight": ["lm_head.weight"],
@ -940,6 +1003,7 @@ def get_model(
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
aliases={
"lm_head.weight": ["transformer.word_embeddings.weight"],
"transformer.word_embeddings.weight": ["lm_head.weight"],
@ -967,6 +1031,7 @@ def get_model(
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
@ -991,6 +1056,7 @@ def get_model(
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
@ -1015,6 +1081,7 @@ def get_model(
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
@ -1041,6 +1108,7 @@ def get_model(
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
@ -1085,7 +1153,7 @@ def get_model(
)
if model_type == IDEFICS:
if FLASH_ATTENTION:
return IDEFICSSharded(
return IdeficsCausalLM(
model_id,
revision,
quantize=quantize,
@ -1095,6 +1163,22 @@ def get_model(
)
else:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
if model_type == MLLAMA:
if FLASH_ATTENTION:
return MllamaCausalLM(
model_id=model_id,
model_class=MllamaForConditionalGeneration,
batch_class=MllamaCausalLMBatch,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
default_dtype=torch.bfloat16,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
else:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Mllama"))
if model_type == IDEFICS2:
if FLASH_ATTENTION:
return VlmCausalLM(
@ -1104,6 +1188,7 @@ def get_model(
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
# XXX: Extremely important to cap resolution in order to limit
@ -1121,6 +1206,7 @@ def get_model(
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
# Works better for these models
default_dtype=torch.bfloat16,
trust_remote_code=trust_remote_code,
@ -1139,6 +1225,7 @@ def get_model(
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
)
else:
@ -1211,6 +1298,7 @@ def get_model_with_lora_adapters(
quantize: Optional[str],
speculate: Optional[int],
dtype: Optional[str],
kv_cache_dtype: Optional[str],
trust_remote_code: bool,
max_input_tokens: int,
adapter_to_index: Dict[str, int],
@ -1224,6 +1312,7 @@ def get_model_with_lora_adapters(
quantize,
speculate,
dtype,
kv_cache_dtype,
trust_remote_code,
max_input_tokens,
)
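The kv_cache_dtype parameter threaded through get_model and get_model_with_lora_adapters above resolves to a torch dtype early in get_model; a hedged restatement of that added branch (only fp8_e5m2 is recognized besides the default):

import torch

def resolve_kv_cache_dtype(kv_cache_dtype, dtype):
    # Default: cache keys/values in the same precision as the activations.
    if kv_cache_dtype is None:
        return dtype
    if kv_cache_dtype == "fp8_e5m2":
        return torch.float8_e5m2  # 8-bit cache: 5 exponent bits, 2 mantissa bits
    raise RuntimeError(f"Unknown kv_cache_dtype: {kv_cache_dtype}")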

View File

@ -517,11 +517,10 @@ class CausalLM(Model):
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = default_dtype if dtype is None else dtype
elif SYSTEM == "ipex":
if hasattr(torch, "xpu") and torch.xpu.is_available():
elif hasattr(torch, "xpu") and torch.xpu.is_available():
device = torch.device(f"xpu:{rank}")
dtype = default_dtype if dtype is None else dtype
else:
elif SYSTEM == "ipex":
device = torch.device("cpu")
# Float16 doesn't exist on target.
dtype = torch.bfloat16 if dtype is None else dtype
@ -593,8 +592,14 @@ class CausalLM(Model):
if speculator:
raise RuntimeError("Speculator decoding is not enabled for AutoModel")
device_count = 0
if torch.cuda.is_available():
device = torch.device("cuda")
device_count = torch.cuda.device_count()
dtype = torch.float16 if dtype is None else dtype
elif hasattr(torch, "xpu") and torch.xpu.is_available():
device = torch.device("xpu")
device_count = torch.xpu.device_count()
dtype = torch.float16 if dtype is None else dtype
else:
if quantize:
@ -614,20 +619,12 @@ class CausalLM(Model):
model_id,
revision=revision,
torch_dtype=dtype,
device_map=(
"auto"
if torch.cuda.is_available() and torch.cuda.device_count() > 1
else None
),
device_map=("auto" if device_count > 1 else None),
load_in_8bit=quantize == "bitsandbytes",
trust_remote_code=trust_remote_code,
)
if (
torch.cuda.is_available()
and torch.cuda.device_count() == 1
and quantize != "bitsandbytes"
):
model = model.cuda()
if device_count == 1 and quantize != "bitsandbytes":
model = model.to(device)
if tokenizer.pad_token_id is None:
if model.config.pad_token_id is not None:
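The fallback path above now derives both device_map and the final model move from a single device_count, covering CUDA, XPU, and CPU uniformly instead of special-casing CUDA. A condensed sketch of the resulting logic:

import torch

def pick_device():
    if torch.cuda.is_available():
        return torch.device("cuda"), torch.cuda.device_count()
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return torch.device("xpu"), torch.xpu.device_count()
    return torch.device("cpu"), 0

device, device_count = pick_device()
device_map = "auto" if device_count > 1 else None  # shard only across multiple devices
# With exactly one device and no bitsandbytes quantization, the model is moved
# explicitly: model = model.to(device)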

View File

@ -28,7 +28,6 @@ from typing import Optional, List, Tuple
from text_generation_server.layers.attention import (
paged_attention,
attention,
reshape_and_cache,
Seqlen,
)
from text_generation_server.utils.import_utils import SYSTEM
@ -291,15 +290,15 @@ class FlashCohereAttention(torch.nn.Module):
self.rotary_emb(query, key, cos, sin)
reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots)
kv_cache.store(key=key, value=value, slots=slots)
# Prefill
if cu_seqlen_prefill is not None:
# flash attention
attn_output = attention(
query,
kv_cache[0] if PREFILL_IN_KV_CACHE else key,
kv_cache[1] if PREFILL_IN_KV_CACHE else value,
kv_cache.key if PREFILL_IN_KV_CACHE else key,
kv_cache.value if PREFILL_IN_KV_CACHE else value,
seqlen,
block_tables,
self.softmax_scale,
@ -308,8 +307,8 @@ class FlashCohereAttention(torch.nn.Module):
else:
attn_output = paged_attention(
query,
kv_cache[0],
kv_cache[1],
kv_cache.key,
kv_cache.value,
self.kv_head_mapping,
self.softmax_scale,
block_tables,
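These attention modules now go through a KVCache object (kv_cache.store, kv_cache.key, kv_cache.value) instead of indexing a (key, value) tuple and calling reshape_and_cache directly; the same refactor repeats in the Dbrx, DeepSeek-V2, and Gemma files below. A plausible minimal shape of that interface, inferred from usage in these hunks rather than from the class's actual definition:

from dataclasses import dataclass
import torch

@dataclass
class KVCache:
    key: torch.Tensor    # paged key cache, e.g. [num_blocks, block_size, heads, head_dim]
    value: torch.Tensor  # paged value cache, same layout (layout assumed)

    def store(self, *, key: torch.Tensor, value: torch.Tensor, slots: torch.Tensor):
        # Scatter this step's K/V rows into their assigned slots; the real
        # implementation dispatches to a per-backend reshape_and_cache kernel.
        k = self.key.view(-1, *self.key.shape[2:])
        v = self.value.view(-1, *self.value.shape[2:])
        k[slots] = key
        v[slots] = value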

View File

@ -28,7 +28,6 @@ if SYSTEM != "ipex":
from text_generation_server.layers.attention import (
paged_attention,
attention,
reshape_and_cache,
Seqlen,
PREFILL_IN_KV_CACHE,
)
@ -330,15 +329,15 @@ class DbrxAttention(torch.nn.Module):
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
kv_cache.store(key=kv[:, 0], value=kv[:, 1], slots=slots)
# Prefill
if cu_seqlen_prefill is not None:
# flash attention
attn_output = attention(
query,
kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0],
kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1],
kv_cache.key if PREFILL_IN_KV_CACHE else kv[:, 0],
kv_cache.value if PREFILL_IN_KV_CACHE else kv[:, 1],
seqlen,
block_tables,
self.softmax_scale,
@ -347,8 +346,8 @@ class DbrxAttention(torch.nn.Module):
else:
attn_output = paged_attention(
query,
kv_cache[0],
kv_cache[1],
kv_cache.key,
kv_cache.value,
self.kv_head_mapping,
self.softmax_scale,
block_tables,

View File

@ -33,7 +33,6 @@ from text_generation_server.layers.attention import (
Seqlen,
attention,
paged_attention,
reshape_and_cache,
)
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
from text_generation_server.layers.layernorm import FastRMSNorm
@ -321,15 +320,15 @@ class DeepseekV2Attention(torch.nn.Module):
value, (0, self.head_pad_size - self.value_head_size), value=0
)
reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots)
kv_cache.store(key=key, value=value, slots=slots)
# Prefill
if cu_seqlen_prefill is not None:
# flash attention
attn_output = attention(
query,
kv_cache[0] if PREFILL_IN_KV_CACHE else key,
kv_cache[1] if PREFILL_IN_KV_CACHE else value,
kv_cache.key if PREFILL_IN_KV_CACHE else key,
kv_cache.value if PREFILL_IN_KV_CACHE else value,
seqlen,
block_tables,
self.softmax_scale,
@ -338,8 +337,8 @@ class DeepseekV2Attention(torch.nn.Module):
else:
attn_output = paged_attention(
query,
kv_cache[0],
kv_cache[1],
kv_cache.key,
kv_cache.value,
self.kv_head_mapping,
self.softmax_scale,
block_tables,
@ -390,6 +389,7 @@ class DeepseekV2MLP(nn.Module):
if (
SYSTEM == "rocm"
and self.hidden_act == "silu"
and hidden_states.dtype == torch.float16
and hidden_states.shape[0] == 1
and not self.quantize
):
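The added dtype check narrows the ROCm fused-SiLU fast path to fp16, single-row (decode-time) inputs. Restated as a standalone predicate, with names adapted from the self.* attributes above:

import torch
from text_generation_server.utils.import_utils import SYSTEM

def use_rocm_custom_silu(hidden_states: torch.Tensor, hidden_act: str, quantize) -> bool:
    return (
        SYSTEM == "rocm"
        and hidden_act == "silu"
        and hidden_states.dtype == torch.float16  # new: the custom kernel is fp16-only
        and hidden_states.shape[0] == 1           # single-row (decode) input
        and not quantize
    )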

View File

@ -28,7 +28,6 @@ from typing import Optional, List, Tuple
from text_generation_server.layers.attention import (
paged_attention,
attention,
reshape_and_cache,
Seqlen,
)
from text_generation_server.layers import (
@ -253,15 +252,15 @@ class FlashGemma2Attention(torch.nn.Module):
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
kv_cache.store(key=kv[:, 0], value=kv[:, 1], slots=slots)
# Prefill
if cu_seqlen_prefill is not None:
# flash attention
attn_output = attention(
query,
kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0],
kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1],
kv_cache.key if PREFILL_IN_KV_CACHE else kv[:, 0],
kv_cache.value if PREFILL_IN_KV_CACHE else kv[:, 1],
seqlen,
block_tables,
self.softmax_scale,
@ -273,8 +272,8 @@ class FlashGemma2Attention(torch.nn.Module):
else:
attn_output = paged_attention(
query,
kv_cache[0],
kv_cache[1],
kv_cache.key,
kv_cache.value,
self.kv_head_mapping,
self.softmax_scale,
block_tables,

View File

@ -28,7 +28,6 @@ from typing import Optional, List, Tuple
from text_generation_server.layers.attention import (
paged_attention,
attention,
reshape_and_cache,
Seqlen,
PREFILL_IN_KV_CACHE,
)
@ -224,15 +223,15 @@ class FlashGemmaAttention(torch.nn.Module):
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
kv_cache.store(key=kv[:, 0], value=kv[:, 1], slots=slots)
# Prefill
if cu_seqlen_prefill is not None:
# flash attention
attn_output = attention(
query,
kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0],
kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1],
kv_cache.key if PREFILL_IN_KV_CACHE else kv[:, 0],
kv_cache.value if PREFILL_IN_KV_CACHE else kv[:, 1],
seqlen,
block_tables,
self.softmax_scale,
@ -242,8 +241,8 @@ class FlashGemmaAttention(torch.nn.Module):
else:
attn_output = paged_attention(
query,
kv_cache[0],
kv_cache[1],
kv_cache.key,
kv_cache.value,
self.kv_head_mapping,
self.softmax_scale,
block_tables,

Some files were not shown because too many files have changed in this diff.